Commit 92aa2fa8 authored by zachteed

initial commit
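# --- file: geom/graph_utils.py (path inferred from the import statements below) ---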
import torch
import numpy as np
from collections import OrderedDict
import lietorch
from data_readers.rgbd_utils import compute_distance_matrix_flow
def graph_to_edge_list(graph):
ii, jj, kk = [], [], []
for s, u in enumerate(graph):
for v in graph[u]:
ii.append(u)
jj.append(v)
kk.append(s)
ii = torch.as_tensor(ii).cuda()
jj = torch.as_tensor(jj).cuda()
kk = torch.as_tensor(kk).cuda()
return ii, jj, kk
def keyframe_indicies(graph):
return torch.as_tensor([u for u in graph]).cuda()
def meshgrid(m, n, device='cuda'):
ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n))
return ii.reshape(-1).to(device), jj.reshape(-1).to(device)
class KeyframeGraph:
def __init__(self, images, poses, depths, intrinsics):
self.images = images.cpu()
self.depths = depths.cpu()
self.poses = poses
self.intrinsics = intrinsics
# sample depths on the 1/8-resolution feature grid and convert to inverse depth
depths = depths[..., 3::8, 3::8].float().cuda()
disps = torch.where(depths>0.1, 1.0/depths, depths)
N = poses.shape[1]
d = compute_distance_matrix_flow(poses, disps, intrinsics / 8.0)
i, j = 0, 0
ixs = [ i ]
while j < N-1:
if d[i, j+1] > 7.5:
ixs += [ j ]
i = j
j += 1
# indices of keyframes
self.distance_matrix = d[ixs][:,ixs]
self.ixs = np.array(ixs)
self.frame_graph = {}
for i in range(N):
k = np.argmin(np.abs(i - self.ixs))
j = self.ixs[k]
self.frame_graph[i] = (k, poses[:,i] * poses[:,j].inv())
def get_keyframes(self):
ix = torch.as_tensor(self.ixs).cuda()
return self.images[:,ix], self.poses[:,ix], self.depths[:,ix], self.intrinsics[:,ix]
def get_graph(self, num=-1, thresh=24.0, r=2):
d = self.distance_matrix.copy()
N = d.shape[0]
if num < 0:
num = N
graph = OrderedDict()
for i in range(N):
graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2]
for i in range(N):
for j in range(i-r, i+r+1):
if j >= 0 and j < N:
d[i,j] = np.inf
for _ in range(num):
ix = np.argmin(d)
i, j = ix // N, ix % N
if d[i,j] < thresh:
graph[i].append(j)
for ii in range(i-r, i+r+1):
for jj in range(j-r, j+r+1):
if ii>=0 and jj>=0 and ii<N and jj<N:
d[ii,jj] = np.inf
else:
break
return graph
def get_poses(self, keyframe_poses):
poses_list = []
for i in range(self.poses.shape[1]):
k, dP = self.frame_graph[i]
poses_list += [dP * keyframe_poses[:,k]]
return lietorch.stack(poses_list, 1)
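# --- new file: losses module in the geom package (exact path not shown in this view) ---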
import numpy as np
import torch
from lietorch import SO3, SE3, Sim3
from .graph_utils import graph_to_edge_list
def pose_metrics(dE):
""" Translation/Rotation/Scaling metrics from Sim3 """
t, q, s = dE.data.split([3, 4, 1], -1)
ang = SO3(q).log().norm(dim=-1)
# convert radians to degrees
r_err = (180 / np.pi) * ang
t_err = t.norm(dim=-1)
s_err = (s - 1.0).abs()
return r_err, t_err, s_err
def geodesic_loss(Ps, Gs, graph, gamma=0.9):
""" Loss function for training network """
# relative pose
ii, jj, kk = graph_to_edge_list(graph)
dP = Ps[:,jj] * Ps[:,ii].inv()
n = len(Gs)
geodesic_loss = 0.0
for i in range(n):
w = gamma ** (n - i - 1)
dG = Gs[i][:,jj] * Gs[i][:,ii].inv()
# pose error
d = (dG * dP.inv()).log()
if isinstance(dG, SE3):
tau, phi = d.split([3,3], dim=-1)
geodesic_loss += w * (
tau.norm(dim=-1).mean() +
phi.norm(dim=-1).mean())
elif isinstance(dG, Sim3):
tau, phi, sig = d.split([3,3,1], dim=-1)
geodesic_loss += w * (
tau.norm(dim=-1).mean() +
phi.norm(dim=-1).mean() +
0.05 * sig.norm(dim=-1).mean())
dE = Sim3(dG * dP.inv()).detach()
r_err, t_err, s_err = pose_metrics(dE)
metrics = {
'r_error': r_err.mean().item(),
't_error': t_err.mean().item(),
's_error': s_err.mean().item(),
}
return geodesic_loss, metrics
def residual_loss(residuals, gamma=0.9):
""" loss on system residuals """
residual_loss = 0.0
n = len(residuals)
for i in range(n):
w = gamma ** (n - i - 1)
residual_loss += w * residuals[i].abs().mean()
return residual_loss, {'residual': residual_loss.item()}
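# --- file: geom/projective_ops.py (path inferred from "import geom.projective_ops as pops" below) ---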
import torch
import torch.nn.functional as F
from lietorch import SE3, Sim3
MIN_DEPTH = 0.1
def extract_intrinsics(intrinsics):
return intrinsics[...,None,None,:].unbind(dim=-1)
def iproj(disps, intrinsics):
""" pinhole camera inverse projection """
ht, wd = disps.shape[2:]
fx, fy, cx, cy = extract_intrinsics(intrinsics)
y, x = torch.meshgrid(
torch.arange(ht).to(disps.device).float(),
torch.arange(wd).to(disps.device).float())
i = torch.ones_like(disps)
X = (x - cx) / fx
Y = (y - cy) / fy
return torch.stack([X, Y, i, disps], dim=-1)
def proj(Xs, intrinsics, jacobian=False):
""" pinhole camera projection """
fx, fy, cx, cy = extract_intrinsics(intrinsics)
X, Y, Z, D = Xs.unbind(dim=-1)
d = torch.where(Z.abs() < 0.001, torch.zeros_like(Z), 1.0/Z)
x = fx * (X * d) + cx
y = fy * (Y * d) + cy
coords = torch.stack([x,y, D*d], dim=-1)
if jacobian:
B, N, H, W = d.shape
o = torch.zeros_like(d)
proj_jac = torch.stack([
fx*d, o, -fx*X*d*d, o,
o, fy*d, -fy*Y*d*d, o,
o, o, -D*d*d, d,
], dim=-1).view(B, N, H, W, 3, 4)
return coords, proj_jac
return coords, None
def actp(Gij, X0, jacobian=False):
""" action on point cloud """
X1 = Gij[:,:,None,None] * X0
if jacobian:
X, Y, Z, d = X1.unbind(dim=-1)
o = torch.zeros_like(d)
B, N, H, W = d.shape
if isinstance(Gij, SE3):
Ja = torch.stack([
d, o, o, o, Z, -Y,
o, d, o, -Z, o, X,
o, o, d, Y, -X, o,
o, o, o, o, o, o,
], dim=-1).view(B, N, H, W, 4, 6)
elif isinstance(Gij, Sim3):
Ja = torch.stack([
d, o, o, o, Z, -Y, X,
o, d, o, -Z, o, X, Y,
o, o, d, Y, -X, o, Z,
o, o, o, o, o, o, o
], dim=-1).view(B, N, H, W, 4, 7)
return X1, Ja
return X1, None
def projective_transform(poses, depths, intrinsics, ii, jj, jacobian=False):
""" map points from ii->jj """
# inverse project (pinhole)
X0 = iproj(depths[:,ii], intrinsics[:,ii])
# transform
Gij = poses[:,jj] * poses[:,ii].inv()
X1, Ja = actp(Gij, X0, jacobian=jacobian)
# project (pinhole)
x1, Jp = proj(X1, intrinsics[:,jj], jacobian=jacobian)
# exclude points too close to camera
valid = ((X1[...,2] > MIN_DEPTH) & (X0[...,2] > MIN_DEPTH)).float()
valid = valid.unsqueeze(-1)
if jacobian:
Jj = torch.matmul(Jp, Ja)
Ji = -Gij[:,:,None,None,None].adjT(Jj)
return x1, valid, (Ji, Jj)
return x1, valid
def induced_flow(poses, disps, intrinsics, ii, jj):
""" optical flow induced by camera motion """
ht, wd = disps.shape[2:]
y, x = torch.meshgrid(
torch.arange(ht).to(disps.device).float(),
torch.arange(wd).to(disps.device).float())
coords0 = torch.stack([x, y], dim=-1)
coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj)
return coords1[...,:2] - coords0, valid
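# --- file: geom/sampler_utils.py (path inferred from the import statements below) ---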
import torch
import torch.nn.functional as F
def _bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def bilinear_sampler(img, coords):
""" Wrapper for bilinear sampler for inputs with extra batch dimensions """
unflatten = False
if len(img.shape) == 5:
unflatten = True
b, n, c, h, w = img.shape
img = img.view(b*n, c, h, w)
coords = coords.view(b*n, h, w, 2)
img1 = _bilinear_sampler(img, coords)
if unflatten:
return img1.view(b, n, c, h, w)
return img1
def sample_depths(depths, coords):
batch, num, ht, wd = depths.shape
depths = depths.view(batch, num, 1, ht, wd)
coords = coords.view(batch, num, ht, wd, 2)
depths_proj = bilinear_sampler(depths, coords)
return depths_proj.view(batch, num, ht, wd, 1)
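# --- new file: training logger module (exact path not shown in this view) ---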
import torch
from torch.utils.tensorboard import SummaryWriter
SUM_FREQ = 100
class Logger:
def __init__(self, name, scheduler):
self.total_steps = 0
self.running_loss = {}
self.writer = None
self.name = name
self.scheduler = scheduler
def _print_training_status(self):
if self.writer is None:
self.writer = SummaryWriter('runs/%s' % self.name)
print([k for k in self.running_loss])
lr = self.scheduler.get_lr().pop()
metrics_data = [self.running_loss[k]/SUM_FREQ for k in self.running_loss.keys()]
training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, lr)
metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
# print the training status
print(training_str + metrics_str)
for key in self.running_loss:
val = self.running_loss[key] / SUM_FREQ
self.writer.add_scalar(key, val, self.total_steps)
self.running_loss[key] = 0.0
def push(self, metrics):
for key in metrics:
if key not in self.running_loss:
self.running_loss[key] = 0.0
self.running_loss[key] += metrics[key]
if self.total_steps % SUM_FREQ == SUM_FREQ-1:
self._print_training_status()
self.running_loss = {}
self.total_steps += 1
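# --- file: networks/modules/clipping.py (path inferred from the import statements below) ---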
import torch
import torch.nn as nn
import torch.nn.functional as F
GRAD_CLIP = .01
class GradClip(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, grad_x):
o = torch.zeros_like(grad_x)
# zero out gradient entries that are too large or NaN, rather than clamping them
grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x)
grad_x = torch.where(torch.isnan(grad_x), o, grad_x)
return grad_x
class GradientClip(nn.Module):
def __init__(self):
super(GradientClip, self).__init__()
def forward(self, x):
return GradClip.apply(x)
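# --- file: networks/modules/corr.py (path inferred from the import statements below) ---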
import torch
import torch.nn.functional as F
from geom.sampler_utils import bilinear_sampler
import lietorch_extras
class CorrSampler(torch.autograd.Function):
@staticmethod
def forward(ctx, volume, coords, radius):
ctx.save_for_backward(volume,coords)
ctx.radius = radius
corr, = lietorch_extras.corr_index_forward(volume, coords, radius)
return corr
@staticmethod
def backward(ctx, grad_output):
volume, coords = ctx.saved_tensors
grad_output = grad_output.contiguous()
grad_volume, = lietorch_extras.corr_index_backward(volume, coords, grad_output, ctx.radius)
return grad_volume, None, None
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, num, h1, w1, h2, w2 = corr.shape
corr = corr.reshape(batch*num*h1*w1, 1, h2, w2)
for i in range(self.num_levels):
self.corr_pyramid.append(
corr.view(batch*num, h1, w1, h2//2**i, w2//2**i))
corr = F.avg_pool2d(corr, 2, stride=2)
def __call__(self, coords):
out_pyramid = []
batch, num, ht, wd, _ = coords.shape
coords = coords.permute(0,1,4,2,3)
coords = coords.contiguous().view(batch*num, 2, ht, wd)
for i in range(self.num_levels):
corr = CorrSampler.apply(self.corr_pyramid[i], coords/2**i, self.radius)
out_pyramid.append(corr.view(batch, num, -1, ht, wd))
return torch.cat(out_pyramid, dim=2)
def append(self, other):
for i in range(self.num_levels):
self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0)
def remove(self, ix):
for i in range(self.num_levels):
self.corr_pyramid[i] = self.corr_pyramid[i][ix].contiguous()
@staticmethod
def corr(fmap1, fmap2):
""" all-pairs correlation """
batch, num, dim, ht, wd = fmap1.shape
fmap1 = fmap1.reshape(batch*num, dim, ht*wd) / 4.0
fmap2 = fmap2.reshape(batch*num, dim, ht*wd) / 4.0
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
return corr.view(batch, num, ht, wd, ht, wd)
class CorrLayer(torch.autograd.Function):
@staticmethod
def forward(ctx, fmap1, fmap2, coords, r):
ctx.r = r
fmap1 = fmap1.contiguous()
fmap2 = fmap2.contiguous()
coords = coords.contiguous()
ctx.save_for_backward(fmap1, fmap2, coords)
corr, = lietorch_extras.altcorr_forward(fmap1, fmap2, coords, ctx.r)
return corr
@staticmethod
def backward(ctx, grad_corr):
fmap1, fmap2, coords = ctx.saved_tensors
grad_corr = grad_corr.contiguous()
fmap1_grad, fmap2_grad, coords_grad = \
lietorch_extras.altcorr_backward(fmap1, fmap2, coords, grad_corr, ctx.r)
return fmap1_grad, fmap2_grad, coords_grad, None
class AltCorrBlock:
def __init__(self, fmaps, inds, num_levels=4, radius=3):
self.num_levels = num_levels
self.radius = radius
self.inds = inds
B, N, C, H, W = fmaps.shape
fmaps = fmaps.view(B*N, C, H, W)
self.pyramid = []
for i in range(self.num_levels):
sz = (B, N, H//2**i, W//2**i, C)
fmap_lvl = fmaps.permute(0, 2, 3, 1)
self.pyramid.append(fmap_lvl.reshape(*sz))
fmaps = F.avg_pool2d(fmaps, 2, stride=2)
def corr_fn(self, coords, ii, jj):
B, N, H, W, S, _ = coords.shape
coords = coords.permute(0, 1, 4, 2, 3, 5)
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][:, ii]
fmap2_i = self.pyramid[i][:, jj]
coords_i = (coords / 2**i).reshape(B*N, S, H, W, 2).contiguous()
fmap1_i = fmap1_i.reshape((B*N,) + fmap1_i.shape[2:])
fmap2_i = fmap2_i.reshape((B*N,) + fmap2_i.shape[2:])
corr = CorrLayer.apply(fmap1_i, fmap2_i, coords_i, self.radius)
corr = corr.view(B, N, S, -1, H, W).permute(0, 1, 3, 4, 5, 2)
corr_list.append(corr)
corr = torch.cat(corr_list, dim=2)
return corr / 16.0
def __call__(self, coords, ii, jj):
squeeze_output = False
if len(coords.shape) == 5:
coords = coords.unsqueeze(dim=-2)
squeeze_output = True
corr = self.corr_fn(coords, ii, jj)
if squeeze_output:
corr = corr.squeeze(dim=-1)
return corr.contiguous()
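# --- file: networks/modules/extractor.py (path inferred from the import statements below) ---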
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
DIM=32
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
self.multidim = multidim
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(DIM)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(DIM)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = DIM
self.layer1 = self._make_layer(DIM, stride=1)
self.layer2 = self._make_layer(2*DIM, stride=2)
self.layer3 = self._make_layer(4*DIM, stride=2)
# output convolution
self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)
if self.multidim:
self.layer4 = self._make_layer(256, stride=2)
self.layer5 = self._make_layer(512, stride=2)
self.in_planes = 256
self.layer6 = self._make_layer(256, stride=1)
self.in_planes = 128
self.layer7 = self._make_layer(128, stride=1)
self.up1 = nn.Conv2d(512, 256, 1)
self.up2 = nn.Conv2d(256, 128, 1)
self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1)
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
b, n, c1, h1, w1 = x.shape
x = x.view(b*n, c1, h1, w1)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
_, c2, h2, w2 = x.shape
return x.view(b, n, c2, h2, w2)
class BasicEncoder16(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
super(BasicEncoder16, self).__init__()
self.norm_fn = norm_fn
self.multidim = multidim
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(DIM)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(DIM)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = DIM
self.layer1 = self._make_layer(DIM, stride=1)
self.layer2 = self._make_layer(2*DIM, stride=2)
self.layer3 = self._make_layer(4*DIM, stride=2)
self.layer4 = self._make_layer(4*DIM, stride=2)
# output convolution
self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
else:
self.dropout = None
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
b, n, c1, h1, w1 = x.shape
x = x.view(b*n, c1, h1, w1)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.conv2(x)
_, c2, h2, w2 = x.shape
return x.view(b, n, c2, h2, w2)
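# --- file: networks/modules/gru.py (path inferred from the import statements below) ---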
import torch
import torch.nn as nn
class ConvGRU(nn.Module):
def __init__(self, h_planes=128, i_planes=128):
super(ConvGRU, self).__init__()
self.do_checkpoint = False
self.convz = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
self.convr = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
self.convq = nn.Conv2d(h_planes+i_planes, h_planes, 3, padding=1)
def forward(self, net, *inputs):
inp = torch.cat(inputs, dim=1)
net_inp = torch.cat([net, inp], dim=1)
z = torch.sigmoid(self.convz(net_inp))
r = torch.sigmoid(self.convr(net_inp))
q = torch.tanh(self.convq(torch.cat([r*net, inp], dim=1)))
net = (1-z) * net + z * q
return net
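# --- new file: UNet module (exact path not shown in this view) ---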
import torch
import torch.nn as nn
# Unet model from https://github.com/usuyama/pytorch-unet
GRAD_CLIP = .01
class GradClip(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, grad_x):
o = torch.zeros_like(grad_x)
grad_x = torch.where(grad_x.abs()>GRAD_CLIP, o, grad_x)
grad_x = torch.where(torch.isnan(grad_x), o, grad_x)
return grad_x
def double_conv(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 5, padding=2),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 5, padding=2),
nn.ReLU(inplace=True)
)
class UNet(nn.Module):
def __init__(self):
super().__init__()
self.dconv_down1 = double_conv(128, 128)
self.dconv_down2 = double_conv(128, 256)
self.dconv_down3 = double_conv(256, 256)
# self.dconv_down4 = double_conv(256, 512)
self.maxpool = nn.AvgPool2d(2)
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.dconv_up3 = double_conv(256 + 256, 256)
self.dconv_up2 = double_conv(256 + 256, 128)
self.dconv_up1 = double_conv(128 + 128, 128)
self.conv_r = nn.Conv2d(128, 3, 1)
self.conv_w = nn.Conv2d(128, 3, 1)
def forward(self, x):
b, n, c, ht, wd = x.shape
x = x.view(b*n, c, ht, wd)
conv1 = self.dconv_down1(x)
x = self.maxpool(conv1)
conv2 = self.dconv_down2(x)
x = self.maxpool(conv2)
conv3 = self.dconv_down3(x)
x = torch.cat([x, conv3], dim=1)
x = self.dconv_up3(x)
x = self.upsample(x)
x = torch.cat([x, conv2], dim=1)
x = self.dconv_up2(x)
x = self.upsample(x)
x = torch.cat([x, conv1], dim=1)
x = self.dconv_up1(x)
r = self.conv_r(x)
w = self.conv_w(x)
w = torch.sigmoid(w)
w = w.view(b, n, 3, ht, wd).permute(0,1,3,4,2)
r = r.view(b, n, 3, ht, wd).permute(0,1,3,4,2)
# w = GradClip.apply(w)
# r = GradClip.apply(r)
return r, w
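# --- file: networks/rslam.py (path inferred from "from .rslam import RaftSLAM" below) ---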
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from networks.modules.extractor import BasicEncoder
from networks.modules.corr import CorrBlock
from networks.modules.gru import ConvGRU
from networks.modules.clipping import GradientClip
from lietorch import SE3, Sim3
from geom.ba import MoBA
import geom.projective_ops as pops
from geom.sampler_utils import bilinear_sampler, sample_depths
from geom.graph_utils import graph_to_edge_list, keyframe_indicies
class UpdateModule(nn.Module):
def __init__(self, args):
super(UpdateModule, self).__init__()
self.args = args
cor_planes = 4 * (2*3 + 1)**2 + 1  # 4 pyramid levels x (2r+1)^2 correlation window, plus 1 mask channel
self.corr_encoder = nn.Sequential(
nn.Conv2d(cor_planes, 128, 1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True))
self.flow_encoder = nn.Sequential(
nn.Conv2d(3, 128, 7, padding=3),
nn.ReLU(inplace=True),
nn.Conv2d(128, 64, 3, padding=1),
nn.ReLU(inplace=True))
self.weight = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 3, 3, padding=1),
GradientClip(),
nn.Sigmoid())
self.delta = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 3, 3, padding=1),
GradientClip())
self.gru = ConvGRU(128, 128+128+64)
def forward(self, net, inp, corr, flow):
""" RaftSLAM update operator """
batch, num, ch, ht, wd = net.shape
output_dim = (batch, num, -1, ht, wd)
net = net.view(batch*num, -1, ht, wd)
inp = inp.view(batch*num, -1, ht, wd)
corr = corr.view(batch*num, -1, ht, wd)
flow = flow.view(batch*num, -1, ht, wd)
corr = self.corr_encoder(corr)
flow = self.flow_encoder(flow)
net = self.gru(net, inp, corr, flow)
### update variables ###
delta = self.delta(net).view(*output_dim)
weight = self.weight(net).view(*output_dim)
delta = delta.permute(0,1,3,4,2).contiguous()
weight = weight.permute(0,1,3,4,2).contiguous()
net = net.view(*output_dim)
return net, delta, weight
class RaftSLAM(nn.Module):
def __init__(self, args):
super(RaftSLAM, self).__init__()
self.args = args
self.fnet = BasicEncoder(output_dim=128, norm_fn='instance')
self.cnet = BasicEncoder(output_dim=256, norm_fn='none')
self.update = UpdateModule(args)
def extract_features(self, images):
""" run feeature extraction networks """
fmaps = self.fnet(images)
net = self.cnet(images)
net, inp = net.split([128,128], dim=2)
net = torch.tanh(net)
inp = torch.relu(inp)
return fmaps, net, inp
def forward(self, Gs, images, depths, intrinsics, graph=None, num_steps=12):
""" Estimates SE3 or Sim3 between pair of frames """
u = keyframe_indicies(graph)
ii, jj, kk = graph_to_edge_list(graph)
depths = depths[:, :, 3::8, 3::8]
intrinsics = intrinsics / 8
mask = (depths > 0.1).float()
disps = torch.where(depths>0.1, 1.0/depths, depths)
fmaps, net, inp = self.extract_features(images)
net, inp = net[:,ii], inp[:,ii]
corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)
coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
residual = torch.zeros_like(coords[...,:2])
Gs_list, coords_list, residual_list = [], [], []
for step in range(num_steps):
Gs = Gs.detach()
coords = coords.detach()
residual = residual.detach()
corr = corr_fn(coords[...,:2])
flow = residual.permute(0,1,4,2,3).clamp(-32.0, 32.0)
corr = torch.cat([corr, mask[:,ii,None]], dim=2)
flow = torch.cat([flow, mask[:,ii,None]], dim=2)
net, delta, weight = self.update(net, inp, corr, flow)
target = coords + delta
weight[...,2] = 0.0
for i in range(3):
Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj)
coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
residual = (target - coords)[...,:2]
Gs_list.append(Gs)
coords_list.append(target)
valid_mask = valid_mask * mask[:,ii].unsqueeze(-1)
residual_list.append(valid_mask * residual)
return Gs_list, residual_list
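# --- new file: Sim3 registration network module (exact path not shown in this view) ---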
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from networks.modules.extractor import BasicEncoder
from networks.modules.corr import CorrBlock
from networks.modules.gru import ConvGRU
from networks.modules.clipping import GradientClip
from lietorch import SE3, Sim3
from geom.ba import MoBA
import geom.projective_ops as pops
from geom.sampler_utils import bilinear_sampler, sample_depths
from geom.graph_utils import graph_to_edge_list, keyframe_indicies
class UpdateModule(nn.Module):
def __init__(self, args):
super(UpdateModule, self).__init__()
self.args = args
cor_planes = 4 * (2*3 + 1)**2  # 4 pyramid levels x (2r+1)^2 correlation window
self.encoder = nn.Sequential(
nn.Conv2d(cor_planes, 128, 1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True))
self.weight = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 3, 3, padding=1),
GradientClip(),
nn.Sigmoid())
self.delta = nn.Sequential(
nn.Conv2d(128, 128, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 3, 3, padding=1),
GradientClip())
self.gru = ConvGRU(128, 128+128+1)
def forward(self, net, inp, corr, dz):
""" update operator """
batch, num, ch, ht, wd = net.shape
output_dim = (batch, num, -1, ht, wd)
net = net.view(batch*num, -1, ht, wd)
inp = inp.view(batch*num, -1, ht, wd)
corr = corr.view(batch*num, -1, ht, wd)
dz = dz.view(batch*num, 1, ht, wd)
corr = self.encoder(corr)
net = self.gru(net, inp, corr, dz)
### update variables ###
delta = self.delta(net).view(*output_dim)
weight = self.weight(net).view(*output_dim)
delta = delta.permute(0,1,3,4,2).contiguous()
weight = weight.permute(0,1,3,4,2).contiguous()
net = net.view(*output_dim)
return net, delta, weight
class Sim3Net(nn.Module):
def __init__(self, args):
super(Sim3Net, self).__init__()
self.args = args
self.fnet = BasicEncoder(output_dim=128, norm_fn='instance')
self.cnet = BasicEncoder(output_dim=256, norm_fn='none')
self.update = UpdateModule(args)
def extract_features(self, images):
""" run feeature extraction networks """
fmaps = self.fnet(images)
net = self.cnet(images)
net, inp = net.split([128,128], dim=2)
net = torch.tanh(net)
inp = torch.relu(inp)
return fmaps, net, inp
def forward(self, Gs, images, depths, intrinsics, graph=None, num_steps=12):
""" Estimates SE3 or Sim3 between pair of frames """
if graph is None:
graph = OrderedDict()
graph[0] = [1]
graph[1] = [0]
u = keyframe_indicies(graph)
ii, jj, kk = graph_to_edge_list(graph)
# use inverse depth parameterization
depths = depths.clamp(min=0.1, max=1000.0)
disps = 1.0 / depths[:, :, 3::8, 3::8]
intrinsics = intrinsics / 8.0
fmaps, net, inp = self.extract_features(images)
corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)
Gs_list, coords_list, residual_list = [], [], []
for step in range(num_steps):
Gs = Gs.detach()
coords1_xyz, _ = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
coords1, zinv_proj = coords1_xyz.split([2,1], dim=-1)
zinv = sample_depths(disps[:,jj], coords1)
dz = (zinv - zinv_proj).clamp(-1.0, 1.0)
corr = corr_fn(coords1)
net, delta, weight = self.update(net, inp, corr, dz)
target = coords1_xyz + delta
for i in range(3):
Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj)
coords1_xyz, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
residual = valid_mask * (target - coords1_xyz)
Gs_list.append(Gs)
coords_list.append(target)
residual_list.append(residual)
return Gs_list, residual_list
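# --- new file: SLAM system module, same package as rslam (exact path not shown in this view) ---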
import numpy as np
import torch
from collections import OrderedDict
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import lietorch
from lietorch import SE3, Sim3
from geom.ba import MoBA
from .modules.corr import CorrBlock, AltCorrBlock
from .rslam import RaftSLAM
import geom.projective_ops as pops
from geom.sampler_utils import bilinear_sampler
from geom.graph_utils import KeyframeGraph, graph_to_edge_list
def meshgrid(m, n, device='cuda'):
ii, jj = torch.meshgrid(torch.arange(m), torch.arange(n))
return ii.reshape(-1).to(device), jj.reshape(-1).to(device)
def normalize_images(images):
# BGR -> RGB, then apply ImageNet mean/std normalization
images = images[:, :, [2,1,0]]
mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
return (images/255.0).sub_(mean[:, None, None]).div_(std[:, None, None])
class FactorGraph:
def __init__(self, hidden=None, inputs=None, residu=None, ii=None, jj=None):
self.hidden = hidden
self.inputs = inputs
self.residu = residu
self.ii = ii
self.jj = jj
def __iadd__(self, other):
if self.hidden is None:
self.hidden = other.hidden
self.inputs = other.inputs
self.residu = other.residu
self.ii = other.ii
self.jj = other.jj
else:
self.hidden = torch.cat([self.hidden, other.hidden], 1)
self.inputs = torch.cat([self.inputs, other.inputs], 1)
self.residu = torch.cat([self.residu, other.residu], 1)
self.ii = torch.cat([self.ii, other.ii], 0)
self.jj = torch.cat([self.jj, other.jj], 0)
return self
def rm(self, keep):
self.hidden = self.hidden[:,keep]
self.inputs = self.inputs[:,keep]
self.residu = self.residu[:,keep]
self.ii = self.ii[keep]
self.jj = self.jj[keep]
class SLAMSystem(RaftSLAM):
def __init__(self, args):
super(SLAMSystem, self).__init__(args)
self.mem = 32
self.num_keyframes = 5
self.frontend = None
self.factors = FactorGraph()
self.count = 0
self.fixed_poses = 1
self.images_list = []
self.depths_list = []
self.intrinsics_list = []
def initialize(self, ht, wd):
""" initialize slam buffers """
self.ht, self.wd = ht, wd
ht, wd = ht // 8, wd // 8
self.fmaps = torch.zeros(1, self.mem, 128, ht, wd, device='cuda', dtype=torch.half)
self.nets = torch.zeros(1, self.mem, 128, ht, wd, device='cuda', dtype=torch.half)
self.inps = torch.zeros(1, self.mem, 128, ht, wd, device='cuda', dtype=torch.half)
self.poses = SE3.Identity(1, 2048, device='cuda')
self.disps = torch.ones(1, 2048, ht, wd, device='cuda')
self.intrinsics = torch.zeros(1, 2048, 4, device='cuda')
self.tstamps = torch.zeros(2048, dtype=torch.long)
def set_frontend(self, frontend):
self.frontend = frontend
def add_point_cloud(self, index, image, pose, depth, intrinsics, s=8):
""" add point cloud to visualization """
if self.frontend is None:
return -1
image = image[...,s//2::s,s//2::s]
depth = depth[...,s//2::s,s//2::s]
intrinsics = intrinsics / s
# backproject
points = pops.iproj(1.0/depth[None], intrinsics[None])
points = points[...,:3] / points[...,[3]]
points = points.reshape(-1, 3)
valid = (depth > 0).reshape(-1)
colors = image.reshape(3,-1).t() / 255.0
point_data = points[valid].cpu().numpy()
color_data = colors[valid].cpu().numpy()
color_data = color_data[:, [2,1,0]]
pose_data = pose.inv()[0].data
self.frontend.update_pose(index, pose_data)
self.frontend.update_points(index, point_data, color_data)
def get_keyframes(self):
""" return keyframe poses and timestamps """
return self.poses[0, :self.count], self.tstamps[:self.count]
def raw_poses(self):
return self.poses[0, :self.count].inv().data
def add_keyframe(self, tstamp, image, depth, intrinsics):
""" add keyframe to factor graph """
if self.count == 0:
ht, wd = image.shape[3:]
self.initialize(ht, wd)
inputs = normalize_images(image)
with autocast(enabled=True):
fmaps, net, inp = self.extract_features(inputs)
ix = self.count % self.mem
self.fmaps[:, ix] = fmaps.squeeze(1)
self.nets[:, ix] = net.squeeze(1)
self.inps[:, ix] = inp.squeeze(1)
self.tstamps[self.count] = tstamp
self.intrinsics[:, self.count] = intrinsics / 8.0
disp = torch.where(depth > 0, 1.0/depth, depth)
self.disps[:, self.count] = disp[:,3::8,3::8]
pose = self.poses[:, self.count-1]
self.add_point_cloud(self.count, image, pose, depth, intrinsics)
self.count += 1
def get_node_attributes(self, index):
index = index % self.mem
return self.fmaps[:, index], self.nets[:, index], self.inps[:, index]
def add_factors(self, ii, jj):
""" add factors to slam graph """
fmaps, hidden, inputs = self.get_node_attributes(ii)
residu_shape = (1, ii.shape[0], self.ht//8, self.wd//8, 2)
residu = torch.zeros(*residu_shape).cuda()
self.factors += FactorGraph(hidden, inputs, residu, ii, jj)
def transform_project(self, ii, jj, **kwargs):
""" helper function, compute project transform """
return pops.projective_transform(self.poses, self.disps, self.intrinsics, ii, jj, **kwargs)
def moba(self, num_steps=5, is_init=False):
""" motion only bundle adjustment """
ii, jj = self.factors.ii, self.factors.jj
ixs = torch.cat([ii, jj], 0)
with autocast(enabled=True):
fmap1 = self.fmaps[:, ii % self.mem]
fmap2 = self.fmaps[:, jj % self.mem]
poses = self.poses[:, :jj.max()+1]
corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=3)
mask = (self.disps[:,ii] > 0.01).float()
with autocast(enabled=False):
coords, valid_mask = pops.projective_transform(poses, self.disps, self.intrinsics, ii, jj)
for i in range(num_steps):
corr = corr_fn(coords[...,:2])
corr = torch.cat([corr, mask[:,:,None]], dim=2)
with autocast(enabled=False):
flow = self.factors.residu.permute(0,1,4,2,3).clamp(-32.0, 32.0)
flow = torch.cat([flow, mask[:,:,None]], dim=2)
self.factors.hidden, delta, weight = \
self.update(self.factors.hidden, self.factors.inputs, corr, flow)
with autocast(enabled=False):
target = coords + delta
weight[...,2] = 0.0
for i in range(3):
poses = MoBA(target, weight, poses, self.disps,
self.intrinsics, ii, jj, self.fixed_poses)
coords, valid_mask = pops.projective_transform(poses, self.disps, self.intrinsics, ii, jj)
self.factors.residu = (target - coords)[...,:2]
self.poses[:, :jj.max()+1] = poses
# update visualization
if self.frontend is not None:
for ix in ixs.cpu().numpy():
self.frontend.update_pose(ix, self.poses[:,ix].inv()[0].data)
def track(self, tstamp, image, depth, intrinsics):
""" main thread """
self.images_list.append(image)
self.depths_list.append(depth)
self.intrinsics_list.append(intrinsics)
# collect frames for initialization
if self.count < self.num_keyframes:
self.add_keyframe(tstamp, image, depth, intrinsics)
if self.count == self.num_keyframes:
ii, jj = meshgrid(self.num_keyframes, self.num_keyframes)
keep = ((ii - jj).abs() > 0) & ((ii - jj).abs() <= 3)
self.add_factors(ii[keep], jj[keep])
self.moba(num_steps=8, is_init=True)
else:
self.poses[:,self.count] = self.poses[:,self.count-1]
self.add_keyframe(tstamp, image, depth, intrinsics)
N = self.count
ii = torch.as_tensor([N-3, N-2, N-1, N-1, N-1], device='cuda')
jj = torch.as_tensor([N-1, N-1, N-2, N-3, N-4], device='cuda')
self.add_factors(ii, jj)
self.moba(num_steps=4)
self.fixed_poses += 1
self.factors.rm(self.factors.ii + 2 >= self.fixed_poses)
def forward(self, poses, images, depths, intrinsics, num_steps=12):
""" Estimates SE3 or Sim3 between pair of frames """
keyframe_graph = KeyframeGraph(images, poses, depths, intrinsics)
images, Gs, depths, intrinsics = keyframe_graph.get_keyframes()
images = images.cuda()
depths = depths.cuda()
if self.frontend is not None:
self.frontend.reset()
for i, ix in enumerate(keyframe_graph.ixs):
self.add_point_cloud(ix, images[:,i], Gs[:,i], depths[:,i], intrinsics[:,i], s=4)
for i in range(poses.shape[1]):
self.frontend.update_pose(i, poses[:,i].inv()[0].data)
graph = keyframe_graph.get_graph()
ii, jj, kk = graph_to_edge_list(graph)
ixs = torch.cat([ii, jj], 0)
images = normalize_images(images.cuda())
depths = depths[:, :, 3::8, 3::8].cuda()
mask = (depths > 0.1).float()
disps = torch.where(depths>0.1, 1.0/depths, depths)
intrinsics = intrinsics / 8
with autocast(True):
fmaps, net, inp = self.extract_features(images)
net = net[:,ii]
# the alternate corr implementation uses less memory but is ~4x slower
corr_fn = AltCorrBlock(fmaps.float(), (ii, jj), num_levels=4, radius=3)
# corr_fn = CorrBlock(fmaps[:,ii], fmaps[:,jj], num_levels=4, radius=3)
with autocast(False):
coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
residual = torch.zeros_like(coords[...,:2])
for step in range(num_steps):
print("Global refinement iteration #{}".format(step))
net_list = []
targets_list = []
weights_list = []
s = 64  # process edges in chunks of 64 to bound memory
for i in range(0, ii.shape[0], s):
ii1 = ii[i:i+s]
jj1 = jj[i:i+s]
corr1 = corr_fn(coords[:,i:i+s,:,:,:2], ii1, jj1)
flow1 = residual[:, i:i+s].permute(0,1,4,2,3).clamp(-32.0, 32.0)
corr1 = torch.cat([corr1, mask[:,ii1,None]], dim=2)
flow1 = torch.cat([flow1, mask[:,ii1,None]], dim=2)
net1, delta, weight = self.update(net[:,i:i+s], inp[:,ii1], corr1, flow1)
net[:,i:i+s] = net1
targets_list += [ coords[:,i:i+s] + delta.float() ]
weights_list += [ weight.float() * torch.as_tensor([1.0, 1.0, 0.0]).cuda() ]
target = torch.cat(targets_list, 1)
weight = torch.cat(weights_list, 1)
with autocast(False):
for i in range(3):
Gs = MoBA(target, weight, Gs, disps, intrinsics, ii, jj, lm=0.00001, ep=.01)
coords, valid_mask = pops.projective_transform(Gs, disps, intrinsics, ii, jj)
residual = (target - coords)[...,:2]
poses = keyframe_graph.get_poses(Gs)
if self.frontend is not None:
for i in range(poses.shape[1]):
self.frontend.update_pose(i, poses[:,i].inv()[0].data)
return poses
def global_refinement(self):
""" run global refinement """
poses = self.poses[:, :self.count]
images = torch.cat(self.images_list, 1).cpu()
depths = torch.stack(self.depths_list, 1).cpu()
intrinsics = torch.stack(self.intrinsics_list, 1)
poses = self.forward(poses, images, depths, intrinsics, num_steps=16)
self.poses[:, :self.count] = poses
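# --- file: pgo/main.py (path inferred from the PGO readme below) ---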
import torch
from lietorch import SO3, SE3, LieGroupParameter
import argparse
import numpy as np
import time
import torch.optim as optim
import torch.nn.functional as F
def draw(verticies):
""" draw pose graph """
import open3d as o3d
n = len(verticies)
points = np.array([x[1][:3] for x in verticies])
lines = np.stack([np.arange(0,n-1), np.arange(1,n)], 1)
line_set = o3d.geometry.LineSet(
points=o3d.utility.Vector3dVector(points),
lines=o3d.utility.Vector2iVector(lines),
)
o3d.visualization.draw_geometries([line_set])
def info2mat(info):
mat = np.zeros((6,6))
ix = 0
for i in range(mat.shape[0]):
mat[i,i:] = info[ix:ix+(6-i)]
mat[i:,i] = info[ix:ix+(6-i)]
ix += (6-i)
return mat
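# g2o 3D pose-graph format, one record per line (as consumed by the parser below):
#   VERTEX_SE3:QUAT <id> <tx ty tz> <qx qy qz qw>
#   EDGE_SE3:QUAT <id_i> <id_j> <tx ty tz> <qx qy qz qw> <21 upper-triangular entries of the 6x6 information matrix>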
def read_g2o(fn):
verticies, edges = [], []
with open(fn) as f:
for line in f:
line = line.split()
if line[0] == 'VERTEX_SE3:QUAT':
v = int(line[1])
pose = np.array(line[2:], dtype=np.float32)
verticies.append([v, pose])
elif line[0] == 'EDGE_SE3:QUAT':
u = int(line[1])
v = int(line[2])
pose = np.array(line[3:10], dtype=np.float32)
info = np.array(line[10:], dtype=np.float32)
info = info2mat(info)
edges.append([u, v, pose, info, line])
return verticies, edges
def write_g2o(pose_graph, fn):
import csv
verticies, edges = pose_graph
with open(fn, 'w') as f:
writer = csv.writer(f, delimiter=' ')
for (v, pose) in verticies:
row = ['VERTEX_SE3:QUAT', v] + pose.tolist()
writer.writerow(row)
for edge in edges:
writer.writerow(edge[-1])
def reshaping_fn(dE, b=1.5):
""" Reshaping function from "Intrinsic consensus on SO(3), Tron et al."""
ang = dE.log().norm(dim=-1)
err = 1/b - (1/b + ang) * torch.exp(-b*ang)
return err.sum()
def gradient_initializer(pose_graph, n_steps=500, lr_init=0.2):
""" Riemannian Gradient Descent """
verticies, edges = pose_graph
# edge indices (ii, jj)
ii = np.array([x[0] for x in edges])
jj = np.array([x[1] for x in edges])
ii = torch.from_numpy(ii).cuda()
jj = torch.from_numpy(jj).cuda()
Eij = np.stack([x[2][3:] for x in edges])
Eij = SO3(torch.from_numpy(Eij).float().cuda())
R = np.stack([x[1][3:] for x in verticies])
R = SO3(torch.from_numpy(R).float().cuda())
R = LieGroupParameter(R)
# use gradient descent with momentum
optimizer = optim.SGD([R], lr=lr_init, momentum=0.5)
start = time.time()
for i in range(n_steps):
optimizer.zero_grad()
for param_group in optimizer.param_groups:
param_group['lr'] = lr_init * .995**i
# rotation error
dE = (R[ii].inv() * R[jj]) * Eij.inv()
loss = reshaping_fn(dE)
loss.backward()
optimizer.step()
if i%25 == 0:
print(i, lr_init * .995**i, loss.item())
# convert rotations to pose3
quats = R.group.data.detach().cpu().numpy()
for i in range(len(verticies)):
verticies[i][1][3:] = quats[i]
return verticies, edges
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--problem', help="input pose graph optimization file (.g2o format)")
parser.add_argument('--steps', type=int, default=500, help="number of gradient descent steps")
args = parser.parse_args()
output_path = args.problem.replace('.g2o', '_rotavg.g2o')
input_pose_graph = read_g2o(args.problem)
rot_pose_graph = gradient_initializer(input_pose_graph, n_steps=args.steps)
write_g2o(rot_pose_graph, output_path)
## Pose Graph Optimization / Rotation Averaging
Pose Graph Optimization (PGO) is the problem of estimating the global trajectory from a set of relative pose measurements. PGO is typically performed using nonlinear least-squares algorithms (e.g. Levenberg-Marquardt) and requires a good initialization in order to converge.
In this experiment, we implement Riemannian Gradient Descent with a reshaping function (Tron et al. 2012). The algorithm is implemented in the function `gradient_initializer` and runs on the GPU using lietorch.
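For reference, the reshaping function used on each edge's residual rotation angle $\theta$ (as implemented in `reshaping_fn` above, with shape parameter `b = 1.5` by default) is

```math
\rho_b(\theta) = \frac{1}{b} - \left(\frac{1}{b} + \theta\right)e^{-b\theta},
```

which is zero at $\theta = 0$ and saturates at $1/b$, so large-angle outlier edges contribute a bounded cost to the objective.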
### Running on a .g2o file
Download a 3D problem from [datasets](https://lucacarlone.mit.edu/datasets/) (our implementation currently only supports uniform information matrices in Sphere-A, Torus, Cube, and Garage).
Then run the `gradient_initializer` on the problem:
```bash
python main.py --problem=torus3D.g2o --steps=500
```
The output graph, `torus3D_rotavg.g2o`, can then be used as the initialization for non-linear least squares optimizers such as `ceres`, `g2o`, and `gtsam`.
# Examples
Instructions for running demos and experiments can be found in each of the example directories
1. [Pose Graph Optimization](pgo/readme.md) -> `pgo`
2. [Sim3 Registration](registration/readme.md) -> `registration`
3. [RGBD-SLAM](rgbdslam/readme.md) -> `rgbdslam`
4. [RAFT-3D (SceneFlow)]()
`core` contains networks, data loaders, and other common utility functions.