#!/usr/bin/python
# Software License Agreement (BSD License)
#
# Copyright (c) 2013, Juergen Sturm, TUM
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of TUM nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Requirements:
# sudo apt-get install python-argparse
"""
The Kinect provides the color and depth images in an un-synchronized way. This means that the set of time stamps from the color images does not intersect with that of the depth images. Therefore, we need some way of associating color images with depth images.
For this purpose, you can use the ''associate.py'' script. It reads the time stamps from the rgb.txt and depth.txt files, and joins them by finding the best matches.
"""
import argparse
import sys
import os
import numpy
def read_file_list(filename):
"""
Reads a trajectory from a text file.
File format:
The file format is "stamp d1 d2 d3 ...", where stamp denotes the time stamp (to be matched)
and "d1 d2 d3..." is arbitrary data (e.g., a 3D position and 3D orientation) associated with this timestamp.
Input:
filename -- File name
Output:
dict -- dictionary of (stamp,data) tuples
"""
f = open(filename)
data = f.read()
f.close()
lines = data.replace(","," ").replace("\t"," ").split("\n")
parsed = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"]
parsed = [(float(l[0]),l[1:]) for l in parsed if len(l)>1]
return dict(parsed)
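# Example of the expected input format (as in the TUM RGB-D benchmark files;
# '#' lines are skipped, commas and tabs are treated as whitespace):
#
#   # timestamp filename
#   1305031102.175304 rgb/1305031102.175304.png
#   1305031102.211214 rgb/1305031102.211214.png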
def associate(first_list, second_list,offset=0.0,max_difference=0.02):
"""
Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim
to find the closest match for every input tuple.
Input:
first_list -- first dictionary of (stamp,data) tuples
second_list -- second dictionary of (stamp,data) tuples
offset -- time offset between both dictionaries (e.g., to model the delay between the sensors)
max_difference -- search radius for candidate generation
Output:
matches -- list of matched timestamp pairs (stamp1,stamp2)
"""
first_keys = list(first_list.keys())
second_keys = list(second_list.keys())
potential_matches = [(abs(a - (b + offset)), a, b)
for a in first_keys
for b in second_keys
if abs(a - (b + offset)) < max_difference]
potential_matches.sort()
matches = []
for diff, a, b in potential_matches:
if a in first_keys and b in second_keys:
first_keys.remove(a)
second_keys.remove(b)
matches.append((a, b))
matches.sort()
return matches
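# Hedged usage sketch (the file names are assumptions):
#
#   python associate.py rgb.txt depth.txt --max_difference 0.02
#
# or from Python:
#
#   first_list = read_file_list("rgb.txt")
#   second_list = read_file_list("depth.txt")
#   matches = associate(first_list, second_list, offset=0.0, max_difference=0.02)
#   # 'matches' is a time-sorted list of (stamp1, stamp2) pairs, chosen
#   # greedily in order of increasing time difference.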
if __name__ == '__main__':
# parse command line
parser = argparse.ArgumentParser(description='''
This script takes two data files with timestamps and associates them
''')
parser.add_argument('first_file', help='first text file (format: timestamp data)')
parser.add_argument('second_file', help='second text file (format: timestamp data)')
parser.add_argument('--first_only', help='only output associated lines from first file', action='store_true')
parser.add_argument('--offset', help='time offset added to the timestamps of the second file (default: 0.0)',default=0.0)
parser.add_argument('--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)',default=0.02)
args = parser.parse_args()
first_list = read_file_list(args.first_file)
second_list = read_file_list(args.second_file)
matches = associate(first_list, second_list,float(args.offset),float(args.max_difference))
if args.first_only:
for a,b in matches:
print("%f %s"%(a," ".join(first_list[a])))
else:
for a,b in matches:
print("%f %s %f %s"%(a," ".join(first_list[a]),b-float(args.offset)," ".join(second_list[b])))
#!/usr/bin/python
# Software License Agreement (BSD License)
#
# Copyright (c) 2013, Juergen Sturm, TUM
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of TUM nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Requirements:
# sudo apt-get install python-argparse
"""
This script computes the absolute trajectory error from the ground truth
trajectory and the estimated trajectory.
"""
import sys
import numpy
import argparse
if __name__=="__main__":
import associate
else:
from . import associate
def align(model,data):
"""Align two trajectories using the method of Horn (closed-form).
Input:
model -- first trajectory (3xn)
data -- second trajectory (3xn)
Output:
rot -- rotation matrix (3x3)
trans -- translation vector (3x1)
trans_error -- translational error per point (1xn)
"""
numpy.set_printoptions(precision=3,suppress=True)
model_zerocentered = model - model.mean(1)
data_zerocentered = data - data.mean(1)
W = numpy.zeros( (3,3) )
for column in range(model.shape[1]):
W += numpy.outer(model_zerocentered[:,column],data_zerocentered[:,column])
U,d,Vh = numpy.linalg.svd(W.transpose())
S = numpy.matrix(numpy.identity( 3 ))
if(numpy.linalg.det(U) * numpy.linalg.det(Vh)<0):
S[2,2] = -1
rot = U*S*Vh
trans = data.mean(1) - rot * model.mean(1)
model_aligned = rot * model + trans
alignment_error = model_aligned - data
trans_error = numpy.sqrt(numpy.sum(numpy.multiply(alignment_error,alignment_error),0)).A[0]
return rot,trans,trans_error
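# Minimal sanity check for align() (synthetic data, for illustration only):
# a pure translation between two point sets should be recovered exactly.
#
#   model = numpy.matrix(numpy.random.rand(3, 10))
#   data = model + numpy.matrix([[1.0], [2.0], [3.0]])  # shift by (1,2,3)
#   rot, trans, trans_error = align(model, data)
#   # rot ~ identity, trans ~ (1,2,3)^T, trans_error ~ 0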
def plot_traj(ax,stamps,traj,style,color,label):
"""
Plot a trajectory using matplotlib.
Input:
ax -- the plot
stamps -- time stamps (1xn)
traj -- trajectory (3xn)
style -- line style
color -- line color
label -- plot legend
"""
stamps.sort()
interval = numpy.median([s-t for s,t in zip(stamps[1:],stamps[:-1])])
x = []
y = []
last = stamps[0]
for i in range(len(stamps)):
if stamps[i]-last < 2*interval:
x.append(traj[i][0])
y.append(traj[i][1])
elif len(x)>0:
ax.plot(x,y,style,color=color,label=label)
label=""
x=[]
y=[]
last= stamps[i]
if len(x)>0:
ax.plot(x,y,style,color=color,label=label)
def evaluate_ate(first_list, second_list, _args=""):
# parse command line
parser = argparse.ArgumentParser(
description='This script computes the absolute trajectory error from the ground truth trajectory and the estimated trajectory.')
# parser.add_argument('first_file', help='ground truth trajectory (format: timestamp tx ty tz qx qy qz qw)')
# parser.add_argument('second_file', help='estimated trajectory (format: timestamp tx ty tz qx qy qz qw)')
parser.add_argument('--offset', help='time offset added to the timestamps of the second file (default: 0.0)',default=0.0)
parser.add_argument('--scale', help='scaling factor for the second trajectory (default: 1.0)',default=1.0)
parser.add_argument('--max_difference', help='maximally allowed time difference for matching entries (default: 0.02)',default=0.02)
parser.add_argument('--save', help='save aligned second trajectory to disk (format: stamp2 x2 y2 z2)')
parser.add_argument('--save_associations', help='save associated first and aligned second trajectory to disk (format: stamp1 x1 y1 z1 stamp2 x2 y2 z2)')
parser.add_argument('--plot', help='plot the first and the aligned second trajectory to an image (format: png)')
parser.add_argument('--verbose', help='print all evaluation data (otherwise, only the RMSE absolute translational error in meters after alignment will be printed)', action='store_true')
args = parser.parse_args(_args)
# first_list = associate.read_file_list(args.first_file)
# second_list = associate.read_file_list(args.second_file)
matches = associate.associate(first_list, second_list,float(args.offset),float(args.max_difference))
if len(matches)<2:
raise ValueError("Couldn't find matching timestamp pairs between groundtruth and estimated trajectory! Did you choose the correct sequence?")
first_xyz = numpy.matrix([[float(value) for value in first_list[a][0:3]] for a,b in matches]).transpose()
second_xyz = numpy.matrix([[float(value)*float(args.scale) for value in second_list[b][0:3]] for a,b in matches]).transpose()
rot,trans,trans_error = align(second_xyz,first_xyz)
second_xyz_aligned = rot * second_xyz + trans
first_stamps = list(first_list.keys())
first_stamps.sort()
first_xyz_full = numpy.matrix([[float(value) for value in first_list[b][0:3]] for b in first_stamps]).transpose()
second_stamps = list(second_list.keys())
second_stamps.sort()
second_xyz_full = numpy.matrix([[float(value)*float(args.scale) for value in second_list[b][0:3]] for b in second_stamps]).transpose()
second_xyz_full_aligned = rot * second_xyz_full + trans
if args.verbose:
print( "compared_pose_pairs %d pairs"%(len(trans_error)))
print( "absolute_translational_error.rmse %f m"%numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)))
print( "absolute_translational_error.mean %f m"%numpy.mean(trans_error))
print( "absolute_translational_error.median %f m"%numpy.median(trans_error))
print( "absolute_translational_error.std %f m"%numpy.std(trans_error))
print( "absolute_translational_error.min %f m"%numpy.min(trans_error))
print( "absolute_translational_error.max %f m"%numpy.max(trans_error))
if args.save_associations:
file = open(args.save_associations,"w")
file.write("\n".join(["%f %f %f %f %f %f %f %f"%(a,x1,y1,z1,b,x2,y2,z2) for (a,b),(x1,y1,z1),(x2,y2,z2) in zip(matches,first_xyz.transpose().A,second_xyz_aligned.transpose().A)]))
file.close()
if args.save:
file = open(args.save,"w")
file.write("\n".join(["%f "%stamp+" ".join(["%f"%d for d in line]) for stamp,line in zip(second_stamps,second_xyz_full_aligned.transpose().A)]))
file.close()
if args.plot:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
from matplotlib.patches import Ellipse
fig = plt.figure()
ax = fig.add_subplot(111)
plot_traj(ax,first_stamps,first_xyz_full.transpose().A,'-',"black","ground truth")
plot_traj(ax,second_stamps,second_xyz_full_aligned.transpose().A,'-',"blue","estimated")
label="difference"
for (a,b),(x1,y1,z1),(x2,y2,z2) in zip(matches,first_xyz.transpose().A,second_xyz_aligned.transpose().A):
ax.plot([x1,x2],[y1,y2],'-',color="red",label=label)
label=""
ax.legend()
ax.set_xlabel('x [m]')
ax.set_ylabel('y [m]')
plt.savefig(args.plot,dpi=90)
return {
"compared_pose_pairs": (len(trans_error)),
"absolute_translational_error.rmse": numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)),
"absolute_translational_error.mean": numpy.mean(trans_error),
"absolute_translational_error.median": numpy.median(trans_error),
"absolute_translational_error.std": numpy.std(trans_error),
"absolute_translational_error.min": numpy.min(trans_error),
"absolute_translational_error.max": numpy.max(trans_error),
}
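# Hedged usage sketch: evaluate_ate expects the trajectories to be loaded
# already (file names below are assumptions):
#
#   gt = associate.read_file_list("groundtruth.txt")
#   est = associate.read_file_list("estimated.txt")
#   stats = evaluate_ate(gt, est, ["--verbose"])
#   print(stats["absolute_translational_error.rmse"])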
#!/usr/bin/python
# Software License Agreement (BSD License)
#
# Copyright (c) 2013, Juergen Sturm, TUM
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# * Neither the name of TUM nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
This script computes the relative pose error from the ground truth trajectory
and the estimated trajectory.
"""
import argparse
import random
import numpy
import sys
_EPS = numpy.finfo(float).eps * 4.0
def transform44(l):
"""
Generate a 4x4 homogeneous transformation matrix from a 3D point and unit quaternion.
Input:
l -- tuple consisting of (stamp,tx,ty,tz,qx,qy,qz,qw) where
(tx,ty,tz) is the 3D position and (qx,qy,qz,qw) is the unit quaternion.
Output:
matrix -- 4x4 homogeneous transformation matrix
"""
t = l[1:4]
q = numpy.array(l[4:8], dtype=numpy.float64, copy=True)
nq = numpy.dot(q, q)
if nq < _EPS:
return numpy.array((
( 1.0, 0.0, 0.0, t[0]),
( 0.0, 1.0, 0.0, t[1]),
( 0.0, 0.0, 1.0, t[2]),
( 0.0, 0.0, 0.0, 1.0)
), dtype=numpy.float64)
q *= numpy.sqrt(2.0 / nq)
q = numpy.outer(q, q)
return numpy.array((
(1.0-q[1, 1]-q[2, 2], q[0, 1]-q[2, 3], q[0, 2]+q[1, 3], t[0]),
( q[0, 1]+q[2, 3], 1.0-q[0, 0]-q[2, 2], q[1, 2]-q[0, 3], t[1]),
( q[0, 2]-q[1, 3], q[1, 2]+q[0, 3], 1.0-q[0, 0]-q[1, 1], t[2]),
( 0.0, 0.0, 0.0, 1.0)
), dtype=numpy.float64)
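# Quick check (illustrative values): the identity quaternion (0,0,0,1) maps to
# the identity rotation, so transform44 reduces to a pure translation.
#
#   T = transform44((0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0))
#   # T[:3,:3] ~ identity, T[:3,3] == (1.0, 2.0, 3.0)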
def read_trajectory(filename, matrix=True):
"""
Read a trajectory from a text file.
Input:
filename -- file to be read
matrix -- convert poses to 4x4 matrices
Output:
dictionary of stamped 3D poses
"""
f = open(filename)
data = f.read()
f.close()
lines = data.replace(","," ").replace("\t"," ").split("\n")
parsed = [[float(v.strip()) for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"]
list_ok = []
for i,l in enumerate(parsed):
if l[4:8]==[0,0,0,0]:
continue
isnan = False
for v in l:
if numpy.isnan(v):
isnan = True
break
if isnan:
sys.stderr.write("Warning: entry %d of file '%s' has NaNs, skipping it\n"%(i,filename))
continue
list_ok.append(l)
if matrix :
traj = dict([(l[0],transform44(l[0:])) for l in list_ok])
else:
traj = dict([(l[0],l[1:8]) for l in list_ok])
return traj
def find_closest_index(L,t):
"""
Find the index of the closest value in a list.
Input:
L -- the list
t -- value to be found
Output:
index of the closest element
"""
beginning = 0
difference = abs(L[0] - t)
best = 0
end = len(L)
while beginning < end:
middle = int((end+beginning)/2)
if abs(L[middle] - t) < difference:
difference = abs(L[middle] - t)
best = middle
if t == L[middle]:
return middle
elif L[middle] > t:
end = middle
else:
beginning = middle + 1
return best
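# Note: this is a binary search, so L must be sorted in increasing order
# (all callers pass sorted timestamp lists). Illustrative example:
#
#   find_closest_index([0.0, 1.0, 2.0, 3.0], 1.4)  # -> 1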
def ominus(a,b):
"""
Compute the relative 3D transformation between a and b.
Input:
a -- first pose (homogeneous 4x4 matrix)
b -- second pose (homogeneous 4x4 matrix)
Output:
Relative 3D transformation from a to b.
"""
return numpy.dot(numpy.linalg.inv(a),b)
def scale(a,scalar):
"""
Scale the translational components of a 4x4 homogeneous matrix by a scale factor.
"""
return numpy.array(
[[a[0,0], a[0,1], a[0,2], a[0,3]*scalar],
[a[1,0], a[1,1], a[1,2], a[1,3]*scalar],
[a[2,0], a[2,1], a[2,2], a[2,3]*scalar],
[a[3,0], a[3,1], a[3,2], a[3,3]]]
)
def compute_distance(transform):
"""
Compute the distance of the translational component of a 4x4 homogeneous matrix.
"""
return numpy.linalg.norm(transform[0:3,3])
def compute_angle(transform):
"""
Compute the rotation angle from a 4x4 homogeneous matrix.
"""
# an invitation to 3-d vision, p 27
return numpy.arccos( min(1,max(-1, (numpy.trace(transform[0:3,0:3]) - 1)/2) ))
def distances_along_trajectory(traj):
"""
Compute the translational distances along a trajectory.
"""
keys = sorted(traj.keys())
motion = [ominus(traj[keys[i+1]],traj[keys[i]]) for i in range(len(keys)-1)]
distances = [0]
sum = 0
for t in motion:
sum += compute_distance(t)
distances.append(sum)
return distances
def rotations_along_trajectory(traj,scale):
"""
Compute the angular rotations along a trajectory.
"""
keys = sorted(traj.keys())
motion = [ominus(traj[keys[i+1]],traj[keys[i]]) for i in range(len(keys)-1)]
distances = [0]
sum = 0
for t in motion:
sum += compute_angle(t)*scale
distances.append(sum)
return distances
def evaluate_trajectory(traj_gt,traj_est,param_max_pairs=10000,param_fixed_delta=False,param_delta=1.00,param_delta_unit="s",param_offset=0.00,param_scale=1.00):
"""
Compute the relative pose error between two trajectories.
Input:
traj_gt -- the first trajectory (ground truth)
traj_est -- the second trajectory (estimated trajectory)
param_max_pairs -- number of relative poses to be evaluated
param_fixed_delta -- false: evaluate over all possible pairs
true: only evaluate over pairs with a given distance (delta)
param_delta -- distance between the evaluated pairs
param_delta_unit -- unit for comparison:
"s": seconds
"m": meters
"rad": radians
"deg": degrees
"f": frames
param_offset -- time offset between two trajectories (to model the delay)
param_scale -- scale to be applied to the second trajectory
Output:
list of compared poses and the resulting translation and rotation error
"""
stamps_gt = list(traj_gt.keys())
stamps_est = list(traj_est.keys())
stamps_gt.sort()
stamps_est.sort()
stamps_est_return = []
for t_est in stamps_est:
t_gt = stamps_gt[find_closest_index(stamps_gt,t_est + param_offset)]
t_est_return = stamps_est[find_closest_index(stamps_est,t_gt - param_offset)]
t_gt_return = stamps_gt[find_closest_index(stamps_gt,t_est_return + param_offset)]
if not t_est_return in stamps_est_return:
stamps_est_return.append(t_est_return)
if(len(stamps_est_return)<2):
raise Exception("Number of overlapping timestamps is too small. Did you run the evaluation on the right files?")
if param_delta_unit=="s":
index_est = list(traj_est.keys())
index_est.sort()
elif param_delta_unit=="m":
index_est = distances_along_trajectory(traj_est)
elif param_delta_unit=="rad":
index_est = rotations_along_trajectory(traj_est,1)
elif param_delta_unit=="deg":
index_est = rotations_along_trajectory(traj_est,180/numpy.pi)
elif param_delta_unit=="f":
index_est = range(len(traj_est))
else:
raise Exception("Unknown unit for delta: '%s'"%param_delta_unit)
if not param_fixed_delta:
if(param_max_pairs==0 or len(traj_est)<numpy.sqrt(param_max_pairs)):
pairs = [(i,j) for i in range(len(traj_est)) for j in range(len(traj_est))]
else:
pairs = [(random.randint(0,len(traj_est)-1),random.randint(0,len(traj_est)-1)) for i in range(param_max_pairs)]
else:
pairs = []
for i in range(len(traj_est)):
j = find_closest_index(index_est,index_est[i] + param_delta)
if j!=len(traj_est)-1:
pairs.append((i,j))
if(param_max_pairs!=0 and len(pairs)>param_max_pairs):
pairs = random.sample(pairs,param_max_pairs)
gt_interval = numpy.median([s-t for s,t in zip(stamps_gt[1:],stamps_gt[:-1])])
gt_max_time_difference = 2*gt_interval
result = []
for i,j in pairs:
stamp_est_0 = stamps_est[i]
stamp_est_1 = stamps_est[j]
stamp_gt_0 = stamps_gt[ find_closest_index(stamps_gt,stamp_est_0 + param_offset) ]
stamp_gt_1 = stamps_gt[ find_closest_index(stamps_gt,stamp_est_1 + param_offset) ]
if(abs(stamp_gt_0 - (stamp_est_0 + param_offset)) > gt_max_time_difference or
abs(stamp_gt_1 - (stamp_est_1 + param_offset)) > gt_max_time_difference):
continue
error44 = ominus( scale(
ominus( traj_est[stamp_est_1], traj_est[stamp_est_0] ),param_scale),
ominus( traj_gt[stamp_gt_1], traj_gt[stamp_gt_0] ) )
trans = compute_distance(error44)
rot = compute_angle(error44)
result.append([stamp_est_0,stamp_est_1,stamp_gt_0,stamp_gt_1,trans,rot])
if len(result)<2:
raise Exception("Couldn't find matching timestamp pairs between groundtruth and estimated trajectory!")
return result
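# Hedged usage sketch (file names are assumptions): measure the drift per
# second between a ground-truth and an estimated trajectory.
#
#   traj_gt = read_trajectory("groundtruth.txt")
#   traj_est = read_trajectory("estimated.txt")
#   result = evaluate_trajectory(traj_gt, traj_est,
#                                param_fixed_delta=True,
#                                param_delta=1.0, param_delta_unit="s")
#   trans_error = numpy.array(result)[:,4]  # meters per one-second interval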
def percentile(seq,q):
"""
Return the q-percentile of a list
"""
seq_sorted = list(seq)
seq_sorted.sort()
return seq_sorted[int((len(seq_sorted)-1)*q)]
def evaluate_rpe(_args):
random.seed(0)
parser = argparse.ArgumentParser(description='''
This script computes the relative pose error from the ground truth trajectory and the estimated trajectory.
''')
parser.add_argument('groundtruth_file', help='ground-truth trajectory file (format: "timestamp tx ty tz qx qy qz qw")')
parser.add_argument('estimated_file', help='estimated trajectory file (format: "timestamp tx ty tz qx qy qz qw")')
parser.add_argument('--max_pairs', help='maximum number of pose comparisons (default: 10000, set to zero to disable downsampling)', default=10000)
parser.add_argument('--fixed_delta', help='only consider pose pairs that have a distance of delta delta_unit (e.g., for evaluating the drift per second/meter/radian)', action='store_true')
parser.add_argument('--delta', help='delta for evaluation (default: 1.0)',default=1.0)
parser.add_argument('--delta_unit', help='unit of delta (options: \'s\' for seconds, \'m\' for meters, \'rad\' for radians, \'f\' for frames; default: \'s\')',default='s')
parser.add_argument('--offset', help='time offset between ground-truth and estimated trajectory (default: 0.0)',default=0.0)
parser.add_argument('--scale', help='scaling factor for the estimated trajectory (default: 1.0)',default=1.0)
parser.add_argument('--save', help='text file to which the evaluation will be saved (format: stamp_est0 stamp_est1 stamp_gt0 stamp_gt1 trans_error rot_error)')
parser.add_argument('--plot', help='plot the result to a file (requires --fixed_delta, output format: png)')
parser.add_argument('--verbose', help='print all evaluation data (otherwise, only the mean translational error measured in meters will be printed)', action='store_true')
args = parser.parse_args(_args)
if args.plot and not args.fixed_delta:
raise ValueError("The '--plot' option can only be used in combination with '--fixed_delta'")
traj_gt = read_trajectory(args.groundtruth_file)
traj_est = read_trajectory(args.estimated_file)
result = evaluate_trajectory(traj_gt,
traj_est,
int(args.max_pairs),
args.fixed_delta,
float(args.delta),
args.delta_unit,
float(args.offset),
float(args.scale))
result_arr = numpy.array(result)
stamps = result_arr[:,0]
trans_error = result_arr[:,4]
rot_error = result_arr[:,5]
if args.save:
f = open(args.save,"w")
f.write("\n".join([" ".join(["%f"%v for v in line]) for line in result]))
f.close()
if args.verbose:
print( "compared_pose_pairs %d pairs"%(len(trans_error)))
print( "translational_error.rmse %f m"%numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)))
print( "translational_error.mean %f m"%numpy.mean(trans_error))
print( "translational_error.median %f m"%numpy.median(trans_error))
print( "translational_error.std %f m"%numpy.std(trans_error))
print( "translational_error.min %f m"%numpy.min(trans_error))
print( "translational_error.max %f m"%numpy.max(trans_error))
print( "rotational_error.rmse %f deg"%(numpy.sqrt(numpy.dot(rot_error,rot_error) / len(rot_error)) * 180.0 / numpy.pi))
print( "rotational_error.mean %f deg"%(numpy.mean(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.median %f deg"%(numpy.median(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.std %f deg"%(numpy.std(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.min %f deg"%(numpy.min(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.max %f deg"%(numpy.max(rot_error) * 180.0 / numpy.pi))
else:
print( numpy.mean(trans_error))
if args.plot:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(stamps - stamps[0],trans_error,'-',color="blue")
#ax.plot([t for t,e in err_rot],[e for t,e in err_rot],'-',color="red")
ax.set_xlabel('time [s]')
ax.set_ylabel('translational error [m]')
plt.savefig(args.plot,dpi=300)
return {
"translational_error.rmse": numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)),
"translational_error.mean": numpy.mean(trans_error),
"translational_error.median": numpy.median(trans_error),
"translational_error.std": numpy.std(trans_error),
"translational_error.min": numpy.min(trans_error),
"translational_error.max": numpy.max(trans_error),
"rotational_error.rmse": (numpy.sqrt(numpy.dot(rot_error,rot_error) / len(rot_error)) * 180.0 / numpy.pi),
"rotational_error.mean": (numpy.mean(rot_error) * 180.0 / numpy.pi),
"rotational_error.median": (numpy.median(rot_error) * 180.0 / numpy.pi),
"rotational_error.std": (numpy.std(rot_error) * 180.0 / numpy.pi),
"rotational_error.min": (numpy.min(rot_error) * 180.0 / numpy.pi),
"rotational_error.max": (numpy.max(rot_error) * 180.0 / numpy.pi),
}
if __name__ == '__main__':
random.seed(0)
parser = argparse.ArgumentParser(description='''
This script computes the relative pose error from the ground truth trajectory and the estimated trajectory.
''')
parser.add_argument('groundtruth_file', help='ground-truth trajectory file (format: "timestamp tx ty tz qx qy qz qw")')
parser.add_argument('estimated_file', help='estimated trajectory file (format: "timestamp tx ty tz qx qy qz qw")')
parser.add_argument('--max_pairs', help='maximum number of pose comparisons (default: 10000, set to zero to disable downsampling)', default=10000)
parser.add_argument('--fixed_delta', help='only consider pose pairs that have a distance of delta delta_unit (e.g., for evaluating the drift per second/meter/radian)', action='store_true')
parser.add_argument('--delta', help='delta for evaluation (default: 1.0)',default=1.0)
parser.add_argument('--delta_unit', help='unit of delta (options: \'s\' for seconds, \'m\' for meters, \'rad\' for radians, \'f\' for frames; default: \'s\')',default='s')
parser.add_argument('--offset', help='time offset between ground-truth and estimated trajectory (default: 0.0)',default=0.0)
parser.add_argument('--scale', help='scaling factor for the estimated trajectory (default: 1.0)',default=1.0)
parser.add_argument('--save', help='text file to which the evaluation will be saved (format: stamp_est0 stamp_est1 stamp_gt0 stamp_gt1 trans_error rot_error)')
parser.add_argument('--plot', help='plot the result to a file (requires --fixed_delta, output format: png)')
parser.add_argument('--verbose', help='print all evaluation data (otherwise, only the mean translational error measured in meters will be printed)', action='store_true')
args = parser.parse_args()
if args.plot and not args.fixed_delta:
sys.exit("The '--plot' option can only be used in combination with '--fixed_delta'")
traj_gt = read_trajectory(args.groundtruth_file)
traj_est = read_trajectory(args.estimated_file)
result = evaluate_trajectory(traj_gt,
traj_est,
int(args.max_pairs),
args.fixed_delta,
float(args.delta),
args.delta_unit,
float(args.offset),
float(args.scale))
result_arr = numpy.array(result)
stamps = result_arr[:,0]
trans_error = result_arr[:,4]
rot_error = result_arr[:,5]
if args.save:
f = open(args.save,"w")
f.write("\n".join([" ".join(["%f"%v for v in line]) for line in result]))
f.close()
if args.verbose:
print( "compared_pose_pairs %d pairs"%(len(trans_error)))
print( "translational_error.rmse %f m"%numpy.sqrt(numpy.dot(trans_error,trans_error) / len(trans_error)))
print( "translational_error.mean %f m"%numpy.mean(trans_error))
print( "translational_error.median %f m"%numpy.median(trans_error))
print( "translational_error.std %f m"%numpy.std(trans_error))
print( "translational_error.min %f m"%numpy.min(trans_error))
print( "translational_error.max %f m"%numpy.max(trans_error))
print( "rotational_error.rmse %f deg"%(numpy.sqrt(numpy.dot(rot_error,rot_error) / len(rot_error)) * 180.0 / numpy.pi))
print( "rotational_error.mean %f deg"%(numpy.mean(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.median %f deg"%(numpy.median(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.std %f deg"%(numpy.std(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.min %f deg"%(numpy.min(rot_error) * 180.0 / numpy.pi))
print( "rotational_error.max %f deg"%(numpy.max(rot_error) * 180.0 / numpy.pi))
else:
print( numpy.mean(trans_error))
if args.plot:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(stamps - stamps[0],trans_error,'-',color="blue")
#ax.plot([t for t,e in err_rot],[e for t,e in err_rot],'-',color="red")
ax.set_xlabel('time [s]')
ax.set_ylabel('translational error [m]')
plt.savefig(args.plot,dpi=300)
import sys
sys.path.append('../core')
import cv2
import numpy as np
from collections import OrderedDict
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from data_readers.factory import dataset_factory
from lietorch import SO3, SE3, Sim3
from geom.losses import geodesic_loss, residual_loss
# network
from networks.rslam import RaftSLAM
from logger import Logger
from evaluate import run_evaluation
def show_image(image):
image = image.permute(1, 2, 0).cpu().numpy()
cv2.imshow('image', image / 255.0)
cv2.waitKey()
def normalize_images(images):
images = images[:, :, [2,1,0]]
mean = torch.as_tensor([0.485, 0.456, 0.406], device=images.device)
std = torch.as_tensor([0.229, 0.224, 0.225], device=images.device)
return (images/255.0).sub_(mean[:, None, None]).div_(std[:, None, None])
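# The mean/std above are the standard ImageNet channel statistics, and the
# [2,1,0] index flips OpenCV-style BGR channels to RGB before normalizing.
# Illustrative shape check (sizes are assumptions):
#
#   images = torch.randint(0, 255, (2, 4, 3, 480, 640), device='cuda').float()
#   out = normalize_images(images)  # same shape, roughly zero-mean per channel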
def train(args):
""" Test to make sure project transform correctly maps points """
N = args.n_frames
model = RaftSLAM(args)
model.cuda()
model.train()
if args.ckpt is not None:
model.load_state_dict(torch.load(args.ckpt))
db = dataset_factory(args.datasets, n_frames=N, fmin=16.0, fmax=96.0)
train_loader = DataLoader(db, batch_size=args.batch, shuffle=True, num_workers=4)
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer,
args.lr, args.steps, pct_start=0.01, cycle_momentum=False)
logger = Logger(args.name, scheduler)
should_keep_training = True
total_steps = 0
while should_keep_training:
for i_batch, item in enumerate(train_loader):
optimizer.zero_grad()
graph = OrderedDict()
for i in range(N):
graph[i] = [j for j in range(N) if i!=j and abs(i-j) <= 2]
images, poses, depths, intrinsics = [x.to('cuda') for x in item]
# convert poses w2c -> c2w
Ps = SE3(poses).inv()
Gs = SE3.Identity(Ps.shape, device='cuda')
images = normalize_images(images)
Gs, residuals = model(Gs, images, depths, intrinsics, graph, num_steps=args.iters)
geo_loss, geo_metrics = geodesic_loss(Ps, Gs, graph)
res_loss, res_metrics = residual_loss(residuals)
metrics = {}
metrics.update(geo_metrics)
metrics.update(res_metrics)
loss = args.w1 * geo_loss + args.w2 * res_loss
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
scheduler.step()
logger.push(metrics)
total_steps += 1
if total_steps % 10000 == 0:
PATH = 'checkpoints/%s_%06d.pth' % (args.name, total_steps)
torch.save(model.state_dict(), PATH)
run_evaluation(PATH)
if total_steps >= args.steps:
should_keep_training = False
break
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='bla', help='name your experiment')
parser.add_argument('--ckpt', help='checkpoint to restore')
parser.add_argument('--datasets', nargs='+', help='lists of datasets for training')
parser.add_argument('--batch', type=int, default=2)
parser.add_argument('--iters', type=int, default=8)
parser.add_argument('--steps', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--clip', type=float, default=2.5)
parser.add_argument('--n_frames', type=int, default=4)
parser.add_argument('--w1', type=float, default=10.0)
parser.add_argument('--w2', type=float, default=0.1)
args = parser.parse_args()
import os
if not os.path.isdir('checkpoints'):
os.mkdir('checkpoints')
model = train(args)
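# Hedged example invocation (the dataset key is an assumption; see
# data_readers.factory for the names it actually accepts):
#
#   python train.py --name my_run --datasets tartan --batch 2 \
#       --steps 100000 --lr 0.0001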
import time
import argparse
import torch
import scipy
import numpy as np
import open3d as o3d
from queue import Empty
from multiprocessing import Queue, Process
from scipy.spatial.transform import Rotation
def pose_matrix_from_quaternion(pvec):
""" convert 4x4 pose matrix to (t, q) """
pose = np.eye(4)
pose[:3,:3] = Rotation.from_quat(pvec[3:]).as_matrix()
pose[:3, 3] = pvec[:3]
return pose
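# Illustrative check: the identity quaternion leaves the rotation block as the
# identity, so only the translation is filled in.
#
#   pose_matrix_from_quaternion(np.array([1., 2., 3., 0., 0., 0., 1.]))
#   # -> 4x4 matrix with identity rotation and translation (1, 2, 3)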
def create_camera_actor(is_gt=False, scale=0.05):
""" build open3d camera polydata """
cam_points = scale * np.array([
[ 0, 0, 0],
[-1, -1, 1.5],
[ 1, -1, 1.5],
[ 1, 1, 1.5],
[-1, 1, 1.5],
[-0.5, 1, 1.5],
[ 0.5, 1, 1.5],
[ 0, 1.2, 1.5]])
cam_lines = np.array([[1, 2], [2, 3], [3, 4], [4, 1],
[1, 0], [0, 2], [3, 0], [0, 4], [5, 7], [7, 6]])
camera_actor = o3d.geometry.LineSet(
points=o3d.utility.Vector3dVector(cam_points),
lines=o3d.utility.Vector2iVector(cam_lines))
color = (0.0, 0.0, 0.0) if is_gt else (0.0, 0.8, 0.8)
camera_actor.paint_uniform_color(color)
return camera_actor
def create_point_cloud_actor(points, colors):
""" open3d point cloud from numpy array """
point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(points)
point_cloud.colors = o3d.utility.Vector3dVector(colors)
return point_cloud
def draw_trajectory(queue):
draw_trajectory.queue = queue
draw_trajectory.cameras = {}
draw_trajectory.points = {}
draw_trajectory.ix = 0
draw_trajectory.warmup = 8
def animation_callback(vis):
cam = vis.get_view_control().convert_to_pinhole_camera_parameters()
while True:
try:
data = draw_trajectory.queue.get_nowait()
if data[0] == 'pose':
i, pose, is_gt = data[1:]
# convert to 4x4 matrix
pose = pose_matrix_from_quaternion(pose)
if i in draw_trajectory.cameras:
cam_actor, pose_prev = draw_trajectory.cameras[i]
pose_change = pose @ np.linalg.inv(pose_prev)
cam_actor.transform(pose_change)
vis.update_geometry(cam_actor)
if i in draw_trajectory.points:
pc = draw_trajectory.points[i]
pc.transform(pose_change)
vis.update_geometry(pc)
else:
cam_actor = create_camera_actor(is_gt)
cam_actor.transform(pose)
vis.add_geometry(cam_actor)
if not is_gt:
draw_trajectory.cameras[i] = (cam_actor, pose)
elif data[0] == 'points':
i, points, colors = data[1:]
point_actor = create_point_cloud_actor(points, colors)
pose = draw_trajectory.cameras[i][1]
point_actor.transform(pose)
vis.add_geometry(point_actor)
draw_trajectory.points[i] = point_actor
elif data[0] == 'reset':
draw_trajectory.warmup = -1
for i in draw_trajectory.points:
vis.remove_geometry(draw_trajectory.points[i])
for i in draw_trajectory.cameras:
vis.remove_geometry(draw_trajectory.cameras[i][0])
draw_trajectory.cameras = {}
draw_trajectory.points = {}
except Empty:
break
# hack to allow interacting with the visualization during inference
if len(draw_trajectory.cameras) >= draw_trajectory.warmup:
cam = vis.get_view_control().convert_from_pinhole_camera_parameters(cam)
vis.poll_events()
vis.update_renderer()
vis = o3d.visualization.Visualizer()
vis.register_animation_callback(animation_callback)
vis.create_window(height=540, width=960)
vis.get_render_option().load_from_json("assets/renderoption.json")
vis.run()
vis.destroy_window()
class SLAMFrontend:
def __init__(self):
self.queue = Queue()
self.p = Process(target=draw_trajectory, args=(self.queue, ))
def update_pose(self, index, pose, gt=False):
if isinstance(pose, torch.Tensor):
pose = pose.cpu().numpy()
self.queue.put_nowait(('pose', index, pose, gt))
def update_points(self, index, points, colors):
if isinstance(points, torch.Tensor):
points = points.cpu().numpy()
self.queue.put_nowait(('points', index, points, colors))
def reset(self):
self.queue.put_nowait(('reset', ))
def start(self):
self.p.start()
return self
def join(self):
self.p.join()
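# Hedged usage sketch (pose vectors in [tx ty tz qx qy qz qw] order are an
# assumption based on pose_matrix_from_quaternion above):
#
#   frontend = SLAMFrontend().start()
#   frontend.update_pose(0, np.array([0., 0., 0., 0., 0., 0., 1.]))
#   frontend.update_points(0, points, colors)  # Nx3 float arrays
#   frontend.join()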
__all__ = ['groups']
from .groups import LieGroupParameter, SO3, RxSO3, SE3, Sim3, cat, stack
import torch
import numpy as np
def check_broadcastable(x, y):
assert len(x.shape) == len(y.shape)
for (n, m) in zip(x.shape[:-1], y.shape[:-1]):
assert n==m or n==1 or m==1
def broadcast_inputs(x, y):
""" Automatic broadcasting of missing dimensions """
if y is None:
xs, xd = x.shape[:-1], x.shape[-1]
return (x.view(-1, xd).contiguous(), ), x.shape[:-1]
check_broadcastable(x, y)
xs, xd = x.shape[:-1], x.shape[-1]
ys, yd = y.shape[:-1], y.shape[-1]
out_shape = [max(n,m) for (n,m) in zip(xs,ys)]
if x.shape[:-1] == y.shape[:-1]:
x1 = x.view(-1, xd)
y1 = y.view(-1, yd)
else:
x_expand = [m if n==1 else 1 for (n,m) in zip(xs, ys)]
y_expand = [n if m==1 else 1 for (n,m) in zip(xs, ys)]
x1 = x.repeat(x_expand + [1]).reshape(-1, xd).contiguous()
y1 = y.repeat(y_expand + [1]).reshape(-1, yd).contiguous()
return (x1, y1), tuple(out_shape)
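# Illustrative example (shapes are assumptions): broadcasting a batch of 4
# transforms against a batch of 5.
#
#   x = torch.randn(4, 1, 7)  # e.g. translation + quaternion per element
#   y = torch.randn(1, 5, 7)
#   (x1, y1), out_shape = broadcast_inputs(x, y)
#   # x1.shape == y1.shape == (20, 7); out_shape == (4, 5)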
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW (BLOCK_H * BLOCK_W)
#define CHANNEL_STRIDE 32
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
__global__ void altcorr_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
}
__syncthreads();
for (int n=0; n<N; n++) {
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
if (within_bounds(h1, w1, H1, W1)) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
}
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
}
__syncthreads();
scalar_t s = 0.0;
for (int k=0; k<CHANNEL_STRIDE; k++)
s += f1[k][tid] * f2[k][tid];
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
scalar_t nw = s * (dy) * (dx);
scalar_t ne = s * (dy) * (1-dx);
scalar_t sw = s * (1-dy) * (dx);
scalar_t se = s * (1-dy) * (1-dx);
scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_nw) += nw;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_ne) += ne;
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_sw) += sw;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_se) += se;
}
}
}
}
}
template <typename scalar_t>
__global__ void altcorr_backward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
f1_grad[c1][k1] = 0.0;
}
__syncthreads();
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
for (int n=0; n<N; n++) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
f2_grad[c2][k1] = 0.0;
}
__syncthreads();
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
scalar_t g = 0.0;
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_nw) * dy * dx;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_ne) * dy * (1-dx);
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
for (int k=0; k<CHANNEL_STRIDE; k++) {
f1_grad[k][tid] += g * f2[k][tid];
f2_grad[k][tid] += g * f1[k][tid];
}
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
if (within_bounds(h2, w2, H2, W2))
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
}
}
}
}
__syncthreads();
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
if (within_bounds(h1, w1, H1, W1))
fptr[c+c1] += f1_grad[c1][k1];
}
}
}
std::vector<torch::Tensor> altcorr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H = coords.size(2);
const auto W = coords.size(3);
const auto rd = 2 * radius + 1;
auto opts = fmap1.options();
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
altcorr_forward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {corr};
}
std::vector<torch::Tensor> altcorr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H1 = fmap1.size(1);
const auto W1 = fmap1.size(2);
const auto H2 = fmap2.size(1);
const auto W2 = fmap2.size(2);
const auto C = fmap1.size(3);
auto opts = fmap1.options();
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
altcorr_backward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {fmap1_grad, fmap2_grad, coords_grad};
}
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cuda_fp16.h>
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#define BLOCK 16
__forceinline__ __device__ bool within_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
__global__ void corr_index_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> volume,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
int r)
{
// batch index
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const int n = blockIdx.z;
const int h1 = volume.size(1);
const int w1 = volume.size(2);
const int h2 = volume.size(3);
const int w2 = volume.size(4);
if (!within_bounds(y, x, h1, w1)) {
return;
}
float x0 = coords[n][0][y][x];
float y0 = coords[n][1][y][x];
float dx = x0 - floor(x0);
float dy = y0 - floor(y0);
int rd = 2*r + 1;
for (int i=0; i<rd+1; i++) {
for (int j=0; j<rd+1; j++) {
int x1 = static_cast<int>(floor(x0)) - r + i;
int y1 = static_cast<int>(floor(y0)) - r + j;
if (within_bounds(y1, x1, h2, w2)) {
scalar_t s = volume[n][y][x][y1][x1];
if (i > 0 && j > 0)
corr[n][i-1][j-1][y][x] += s * scalar_t(dx * dy);
if (i > 0 && j < rd)
corr[n][i-1][j][y][x] += s * scalar_t(dx * (1.0f-dy));
if (i < rd && j > 0)
corr[n][i][j-1][y][x] += s * scalar_t((1.0f-dx) * dy);
if (i < rd && j < rd)
corr[n][i][j][y][x] += s * scalar_t((1.0f-dx) * (1.0f-dy));
}
}
}
}
template <typename scalar_t>
__global__ void corr_index_backward_kernel(
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> volume_grad,
int r)
{
// batch index
const int x = blockIdx.x * blockDim.x + threadIdx.x;
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const int n = blockIdx.z;
const int h1 = volume_grad.size(1);
const int w1 = volume_grad.size(2);
const int h2 = volume_grad.size(3);
const int w2 = volume_grad.size(4);
if (!within_bounds(y, x, h1, w1)) {
return;
}
float x0 = coords[n][0][y][x];
float y0 = coords[n][1][y][x];
float dx = x0 - floor(x0);
float dy = y0 - floor(y0);
int rd = 2*r + 1;
for (int i=0; i<rd+1; i++) {
for (int j=0; j<rd+1; j++) {
int x1 = static_cast<int>(floor(x0)) - r + i;
int y1 = static_cast<int>(floor(y0)) - r + j;
if (within_bounds(y1, x1, h2, w2)) {
scalar_t g = 0.0;
if (i > 0 && j > 0)
g += corr_grad[n][i-1][j-1][y][x] * scalar_t(dx * dy);
if (i > 0 && j < rd)
g += corr_grad[n][i-1][j][y][x] * scalar_t(dx * (1.0f-dy));
if (i < rd && j > 0)
g += corr_grad[n][i][j-1][y][x] * scalar_t((1.0f-dx) * dy);
if (i < rd && j < rd)
g += corr_grad[n][i][j][y][x] * scalar_t((1.0f-dx) * (1.0f-dy));
volume_grad[n][y][x][y1][x1] += g;
}
}
}
}
std::vector<torch::Tensor> corr_index_cuda_forward(
torch::Tensor volume,
torch::Tensor coords,
int radius)
{
const auto batch_size = volume.size(0);
const auto ht = volume.size(1);
const auto wd = volume.size(2);
const dim3 blocks((wd + BLOCK - 1) / BLOCK,
(ht + BLOCK - 1) / BLOCK,
batch_size);
const dim3 threads(BLOCK, BLOCK);
auto opts = volume.options();
torch::Tensor corr = torch::zeros(
{batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] {
corr_index_forward_kernel<scalar_t><<<blocks, threads>>>(
volume.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
corr.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
radius);
}));
return {corr};
}
std::vector<torch::Tensor> corr_index_cuda_backward(
torch::Tensor volume,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius)
{
const auto batch_size = volume.size(0);
const auto ht = volume.size(1);
const auto wd = volume.size(2);
auto volume_grad = torch::zeros_like(volume);
const dim3 blocks((wd + BLOCK - 1) / BLOCK,
(ht + BLOCK - 1) / BLOCK,
batch_size);
const dim3 threads(BLOCK, BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] {
corr_index_backward_kernel<scalar_t><<<blocks, threads>>>(
coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
corr_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
volume_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
radius);
}));
return {volume_grad};
}
#include <torch/extension.h>
#include <vector>
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// CUDA forward declarations
std::vector<torch::Tensor> corr_index_cuda_forward(
torch::Tensor volume,
torch::Tensor coords,
int radius);
std::vector<torch::Tensor> corr_index_cuda_backward(
torch::Tensor volume,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius);
std::vector<torch::Tensor> altcorr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius);
std::vector<torch::Tensor> altcorr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius);
std::vector<torch::Tensor> dense_se3_forward_cuda(
torch::Tensor transforms,
torch::Tensor embeddings,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
int radius);
std::vector<torch::Tensor> dense_se3_backward_cuda(
torch::Tensor transforms,
torch::Tensor embeddings,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
torch::Tensor H_grad,
torch::Tensor b_grad,
int radius);
std::vector<torch::Tensor> se3_build_cuda(
torch::Tensor attention,
torch::Tensor transforms,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
int radius);
std::vector<torch::Tensor> se3_build_backward_cuda(
torch::Tensor attention,
torch::Tensor transforms,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
torch::Tensor H_grad,
torch::Tensor b_grad,
int radius);
std::vector<torch::Tensor> cholesky_solve6x6_forward_cuda(
torch::Tensor H, torch::Tensor b);
std::vector<torch::Tensor> cholesky_solve6x6_backward_cuda(
torch::Tensor H, torch::Tensor b, torch::Tensor dx);
// c++ python binding
std::vector<torch::Tensor> corr_index_forward(
torch::Tensor volume,
torch::Tensor coords,
int radius) {
CHECK_INPUT(volume);
CHECK_INPUT(coords);
return corr_index_cuda_forward(volume, coords, radius);
}
std::vector<torch::Tensor> corr_index_backward(
torch::Tensor volume,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius) {
CHECK_INPUT(volume);
CHECK_INPUT(coords);
CHECK_INPUT(corr_grad);
auto volume_grad = corr_index_cuda_backward(volume, coords, corr_grad, radius);
return {volume_grad};
}
std::vector<torch::Tensor> altcorr_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
return altcorr_cuda_forward(fmap1, fmap2, coords, radius);
}
std::vector<torch::Tensor> altcorr_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius) {
CHECK_INPUT(fmap1);
CHECK_INPUT(fmap2);
CHECK_INPUT(coords);
CHECK_INPUT(corr_grad);
return altcorr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}
std::vector<torch::Tensor> se3_build(
torch::Tensor attention,
torch::Tensor transforms,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
int radius) {
CHECK_INPUT(transforms);
CHECK_INPUT(attention);
CHECK_INPUT(points);
CHECK_INPUT(targets);
CHECK_INPUT(weights);
CHECK_INPUT(intrinsics);
return se3_build_cuda(attention, transforms,
points, targets, weights, intrinsics, radius);
}
std::vector<torch::Tensor> se3_build_backward(
torch::Tensor attention,
torch::Tensor transforms,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
torch::Tensor H_grad,
torch::Tensor b_grad,
int radius) {
CHECK_INPUT(transforms);
CHECK_INPUT(attention);
CHECK_INPUT(points);
CHECK_INPUT(targets);
CHECK_INPUT(weights);
CHECK_INPUT(intrinsics);
CHECK_INPUT(H_grad);
CHECK_INPUT(b_grad);
return se3_build_backward_cuda(attention, transforms, points,
targets, weights, intrinsics, H_grad, b_grad, radius);
}
std::vector<torch::Tensor> se3_build_inplace(
torch::Tensor transforms,
torch::Tensor embeddings,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
int radius) {
CHECK_INPUT(transforms);
CHECK_INPUT(embeddings);
CHECK_INPUT(points);
CHECK_INPUT(targets);
CHECK_INPUT(weights);
CHECK_INPUT(intrinsics);
return dense_se3_forward_cuda(transforms, embeddings,
points, targets, weights, intrinsics, radius);
}
std::vector<torch::Tensor> se3_build_inplace_backward(
torch::Tensor transforms,
torch::Tensor embeddings,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
torch::Tensor H_grad,
torch::Tensor b_grad,
int radius) {
CHECK_INPUT(transforms);
CHECK_INPUT(embeddings);
CHECK_INPUT(points);
CHECK_INPUT(targets);
CHECK_INPUT(weights);
CHECK_INPUT(intrinsics);
CHECK_INPUT(H_grad);
CHECK_INPUT(b_grad);
return dense_se3_backward_cuda(transforms, embeddings, points,
targets, weights, intrinsics, H_grad, b_grad, radius);
}
std::vector<torch::Tensor> cholesky6x6_forward(
torch::Tensor H,
torch::Tensor b) {
CHECK_INPUT(H);
CHECK_INPUT(b);
return cholesky_solve6x6_forward_cuda(H, b);
}
std::vector<torch::Tensor> cholesky6x6_backward(
torch::Tensor H,
torch::Tensor b,
torch::Tensor dx) {
CHECK_INPUT(H);
CHECK_INPUT(b);
CHECK_INPUT(dx);
return cholesky_solve6x6_backward_cuda(H, b, dx);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("altcorr_forward", &altcorr_forward, "ALTCORR forward");
m.def("altcorr_backward", &altcorr_backward, "ALTCORR backward");
m.def("corr_index_forward", &corr_index_forward, "INDEX forward");
m.def("corr_index_backward", &corr_index_backward, "INDEX backward");
// RAFT-3D functions
m.def("se3_build", &se3_build, "build forward");
m.def("se3_build_backward", &se3_build_backward, "build backward");
m.def("se3_build_inplace", &se3_build_inplace, "build forward");
m.def("se3_build_inplace_backward", &se3_build_inplace_backward, "build backward");
m.def("cholesky6x6_forward", &cholesky6x6_forward, "solve forward");
m.def("cholesky6x6_backward", &cholesky6x6_backward, "solve backward");
}
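// Usage sketch (illustrative only): once this extension is built, the bindings can
// be driven from Python roughly as below. The module name "lietorch_extras" and the
// exact shapes are assumptions read off the wrappers and kernels, not confirmed by
// this file.
//
//   import torch
//   import lietorch_extras
//
//   B, H, W = 1, 30, 40
//   attention  = torch.rand(B, H, W, H, W, device='cuda')
//   transforms = torch.rand(B, 3, 4, H, W, device='cuda')   # 3x4 transform per pixel
//   points     = torch.rand(B, 3, H, W, device='cuda')
//   targets    = torch.rand(B, 3, H, W, device='cuda')
//   weights    = torch.rand(B, 3, H, W, device='cuda')
//   intrinsics = torch.tensor([[120., 120., 20., 15.]], device='cuda')  # fx, fy, cx, cy
//
//   Hmat, b = lietorch_extras.se3_build(attention, transforms, points,
//                                       targets, weights, intrinsics, 32)
//   x, = lietorch_extras.cholesky6x6_forward(Hmat, b)   # per-pixel solve of H x = b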
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <iostream>
#define NUM_THREADS 64
// #define RADIUS 32
__global__ void se3_build_forward_kernel(
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> attention,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> transforms,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> points,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights,
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> Hx,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> bx,
int radius)
{
  /* Dense transform layer aggregation step
    Inputs:
      attention: [B, H, W, H, W]
      transforms: [B, 3, 4, H, W] (3x4 rigid-body matrix per pixel)
      points: [B, 3, H, W]
      targets: [B, 3, H, W]
      weights: [B, 3, H, W]
      intrinsics: [B, 4]
    Outputs:
      Hx: [B, 6, 6, H, W]
      bx: [B, 6, 1, H, W]
  */
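  /* For each center pixel (h1, w1) this kernel accumulates a per-pixel 6x6
     Gauss-Newton system over all neighbors (h2, w2) within `radius`:
       H(h1,w1) = sum_j a_j * (wu Ju Ju^T + wv Jv Jv^T + wz Jz Jz^T)
       b(h1,w1) = sum_j a_j * (wu ru Ju + wv rv Jv + wz rz Jz)
     where a_j is the attention weight between the two pixels, (ru, rv, rz) the
     reprojection residual in (u, v, inverse depth), and Ju/Jv/Jz the rows of the
     projection Jacobian with respect to the se3 update. */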
int batch_id = blockIdx.x; // batch_index
int tx = threadIdx.x;
int ix = blockIdx.y * NUM_THREADS + tx; // image_index
int ht = attention.size(1);
int wd = attention.size(2);
int dim = ht * wd;
int h1 = ix / wd;
int w1 = ix % wd;
const float* Gdata = transforms[batch_id].data();
const float* Xdata = points[batch_id].data();
const float* rdata = targets[batch_id].data();
const float* wdata = weights[batch_id].data();
__shared__ float fx, fy, cx, cy;
if (tx == 0) {
fx = intrinsics[batch_id][0];
fy = intrinsics[batch_id][1];
cx = intrinsics[batch_id][2];
cy = intrinsics[batch_id][3];
}
float G[12];
if (ix < dim) {
for (int k=0; k<12; k++)
G[k] = Gdata[ix + k*dim];
}
// linear system
float H[6][6];
float b[6];
for (int ii=0; ii<6; ii++) {
b[ii] = 0.0f;
for (int jj=0; jj<6; jj++) {
H[ii][jj] = 0.0f;
}
}
// jacobians
float Ju[6];
float Jv[6];
float Jz[6];
__shared__ float X0[3][NUM_THREADS];
__shared__ float rvec[3][NUM_THREADS];
__shared__ float wvec[3][NUM_THREADS];
__syncthreads();
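  // Tile over all pixels in chunks of NUM_THREADS: each thread cooperatively stages
  // one neighbor (point, residual target, weights) into shared memory, then every
  // thread accumulates the whole tile's contributions into its private H and b.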
for (int i=0; i<dim; i+=NUM_THREADS) {
// load in data
int jx = i + tx;
if (jx < dim) {
X0[0][tx] = Xdata[jx+0*dim];
X0[1][tx] = Xdata[jx+1*dim];
X0[2][tx] = Xdata[jx+2*dim];
rvec[0][tx] = rdata[jx+0*dim];
rvec[1][tx] = rdata[jx+1*dim];
rvec[2][tx] = rdata[jx+2*dim];
wvec[0][tx] = wdata[jx+0*dim];
wvec[1][tx] = wdata[jx+1*dim];
wvec[2][tx] = wdata[jx+2*dim];
}
__syncthreads();
for (int j=0; j<NUM_THREADS; j++) {
jx = i + j;
if (ix<dim && jx<dim) {
int h2 = jx / wd;
int w2 = jx % wd;
int r = max(abs(h1-h2), abs(w1-w2));
if (r > radius)
continue;
float w = attention[batch_id][h1][w1][h2][w2];
float wu = w * wvec[0][j];
float wv = w * wvec[1][j];
float wz = w * wvec[2][j];
float X1, Y1, Z1;
X1 = G[0]*X0[0][j] + G[1]*X0[1][j] + G[2]*X0[2][j] + G[3];
Y1 = G[4]*X0[0][j] + G[5]*X0[1][j] + G[6]*X0[2][j] + G[7];
Z1 = G[8]*X0[0][j] + G[9]*X0[1][j] + G[10]*X0[2][j] + G[11];
        if (Z1 < 0.1) Z1 = 0.001;  // guard the divisions below; such points are rejected by the Z1 < 0.1 test underneath
// residual vectors
float ru = rvec[0][j] - (fx * (X1 / Z1) + cx);
float rv = rvec[1][j] - (fy * (Y1 / Z1) + cy);
float rz = rvec[2][j] - (1.0 / Z1);
if (abs(ru) > 250 || abs(rv) > 250 || Z1 < 0.1) {
continue;
}
float d = 1.f/Z1;
float d2 = d*d;
// x-jacobians
Ju[0] = fx * d;
Ju[1] = fx * 0.0;
Ju[2] = fx * (-X1*d2);
Ju[3] = fx * (-X1*Y1*d2);
Ju[4] = fx * (1 + X1*X1*d2);
Ju[5] = fx * (-Y1*d);
// y-jacobians
Jv[0] = fy * 0.0;
Jv[1] = fy * d;
Jv[2] = fy * (-Y1*d2);
Jv[3] = fy * -1 * (1+Y1*Y1*d2);
Jv[4] = fy * X1*Y1*d2;
Jv[5] = fy * X1*d;
// z-jacobians
Jz[0] = 0.0;
Jz[1] = 0.0;
Jz[2] = -d2;
Jz[3] = d * Y1;
Jz[4] = -d * X1;
Jz[5] = 0.0;
for (int ii=0; ii<6; ii++) {
b[ii] += wu*ru*Ju[ii] + wv*rv*Jv[ii] + wz*rz*Jz[ii];
for (int jj=0; jj<6; jj++) {
H[ii][jj] += wu*Ju[ii]*Ju[jj] + wv*Jv[ii]*Jv[jj] + wz*Jz[ii]*Jz[jj];
}
}
}
}
__syncthreads();
}
if (ix < dim) {
for (int ii=0; ii<6; ii++) {
bx[batch_id][ii][0][h1][w1] = b[ii];
for (int jj=0; jj<6; jj++) {
Hx[batch_id][ii][jj][h1][w1] = H[ii][jj];
}
}
}
}
__global__ void se3_build_backward_kernel(
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> attention,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> transforms,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> points,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights,
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> Hx_grad,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> bx_grad,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> attention_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights_grad,
int radius)
{
int batch_id = blockIdx.x; // batch_index
int tx = threadIdx.x;
int ix = blockIdx.y * NUM_THREADS + tx; // image_index
int ht = attention.size(1);
int wd = attention.size(2);
int dim = ht * wd;
int h2 = ix / wd;
int w2 = ix % wd;
const float* Gdata = transforms[batch_id].data();
const float* Hdata = Hx_grad[batch_id].data();
const float* bdata = bx_grad[batch_id].data();
__shared__ float fx, fy, cx, cy;
if (tx == 0) {
fx = intrinsics[batch_id][0];
fy = intrinsics[batch_id][1];
cx = intrinsics[batch_id][2];
cy = intrinsics[batch_id][3];
}
float X0[3];
X0[0] = points[batch_id][0][h2][w2];
X0[1] = points[batch_id][1][h2][w2];
X0[2] = points[batch_id][2][h2][w2];
float target_u = targets[batch_id][0][h2][w2];
float target_v = targets[batch_id][1][h2][w2];
float target_z = targets[batch_id][2][h2][w2];
float wu = weights[batch_id][0][h2][w2];
float wv = weights[batch_id][1][h2][w2];
float wz = weights[batch_id][2][h2][w2];
// jacobians
float Ju[6], Jv[6], Jz[6];
float diff_ru = 0.0f;
float diff_rv = 0.0f;
float diff_rz = 0.0f;
float diff_wu = 0.0f;
float diff_wv = 0.0f;
float diff_wz = 0.0f;
__shared__ float Gs[12][NUM_THREADS];
__shared__ float H_grad[36][NUM_THREADS];
__shared__ float b_grad[6][NUM_THREADS];
__syncthreads();
for (int i=0; i<dim; i+=NUM_THREADS) {
int jx = i + tx;
if (jx < dim) {
for (int k=0; k<12; k++)
Gs[k][tx] = Gdata[jx + k*dim];
for (int k=0; k<36; k++)
H_grad[k][tx] = Hdata[jx + k*dim];
for (int k=0; k<6; k++)
b_grad[k][tx] = bdata[jx + k*dim];
}
__syncthreads();
for (int j=0; j<NUM_THREADS; j++) {
jx = i + j;
if (ix<dim && jx<dim) {
int h1 = jx / wd;
int w1 = jx % wd;
int r = max(abs(h1-h2), abs(w1-w2));
if (r > radius)
continue;
float w = attention[batch_id][h1][w1][h2][w2];
float diff_w = 0.0f;
float X1, Y1, Z1;
X1 = Gs[0][j]*X0[0] + Gs[1][j]*X0[1] + Gs[2][j]*X0[2] + Gs[3][j];
Y1 = Gs[4][j]*X0[0] + Gs[5][j]*X0[1] + Gs[6][j]*X0[2] + Gs[7][j];
Z1 = Gs[8][j]*X0[0] + Gs[9][j]*X0[1] + Gs[10][j]*X0[2] + Gs[11][j];
          if (Z1 < 0.1) Z1 = 0.001;  // guard the divisions below; such points are rejected by the Z1 < 0.1 test underneath
// residual vectors
float ru = target_u - (fx * (X1 / Z1) + cx);
float rv = target_v - (fy * (Y1 / Z1) + cy);
float rz = target_z - (1.0 / Z1);
          // NOTE: this 50 px cutoff is tighter than the 250 px cutoff in the forward
          // kernel, so residuals between the two thresholds contribute to H, b in the
          // forward pass but receive no gradient here.
          if (abs(ru) > 50 || abs(rv) > 50 || Z1 < 0.1) {
continue;
}
float d = 1.f/Z1;
float d2 = d*d;
// x-jacobians
Ju[0] = fx * d;
Ju[1] = fx * 0.0;
Ju[2] = fx * (-X1*d2);
Ju[3] = fx * (-X1*Y1*d2);
Ju[4] = fx * (1 + X1*X1*d2);
Ju[5] = fx * (-Y1*d);
// y-jacobians
Jv[0] = fy * 0.0;
Jv[1] = fy * d;
Jv[2] = fy * (-Y1*d2);
Jv[3] = fy * -1 * (1+Y1*Y1*d2);
Jv[4] = fy * X1*Y1*d2;
Jv[5] = fy * X1*d;
// z-jacobians
Jz[0] = 0.0;
Jz[1] = 0.0;
Jz[2] = -d2;
Jz[3] = d * Y1;
Jz[4] = -d * X1;
Jz[5] = 0.0;
for (int ii=0; ii<6; ii++) {
// residual gradients
diff_ru += w*wu*Ju[ii]*b_grad[ii][j];
diff_rv += w*wv*Jv[ii]*b_grad[ii][j];
diff_rz += w*wz*Jz[ii]*b_grad[ii][j];
// weights gradients
diff_wu += w*ru*Ju[ii]*b_grad[ii][j];
diff_wv += w*rv*Jv[ii]*b_grad[ii][j];
diff_wz += w*rz*Jz[ii]*b_grad[ii][j];
// embedding weight
diff_w += (wu*ru*Ju[ii] + wv*rv*Jv[ii] + wz*rz*Jz[ii]) * b_grad[ii][j];
for (int jj=0; jj<6; jj++) {
diff_wu += w*Ju[ii]*Ju[jj]*H_grad[6*ii+jj][j];
diff_wv += w*Jv[ii]*Jv[jj]*H_grad[6*ii+jj][j];
diff_wz += w*Jz[ii]*Jz[jj]*H_grad[6*ii+jj][j];
diff_w += (wu*Ju[ii]*Ju[jj] + wv*Jv[ii]*Jv[jj] + wz*Jz[ii]*Jz[jj])*H_grad[6*ii+jj][j];
}
}
attention_grad[batch_id][h1][w1][h2][w2] = diff_w;
}
}
__syncthreads();
}
targets_grad[batch_id][0][h2][w2] = diff_ru;
targets_grad[batch_id][1][h2][w2] = diff_rv;
targets_grad[batch_id][2][h2][w2] = diff_rz;
weights_grad[batch_id][0][h2][w2] = diff_wu;
weights_grad[batch_id][1][h2][w2] = diff_wv;
weights_grad[batch_id][2][h2][w2] = diff_wz;
}
std::vector<torch::Tensor> se3_build_cuda(
torch::Tensor attention,
torch::Tensor transforms,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
int radius)
{
int batch_size = attention.size(0);
int ht = attention.size(1);
int wd = attention.size(2);
dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
auto opts = attention.options();
torch::Tensor H = torch::zeros({batch_size, 6, 6, ht, wd}, opts);
torch::Tensor b = torch::zeros({batch_size, 6, 1, ht, wd}, opts);
se3_build_forward_kernel<<<grid, NUM_THREADS>>>(
attention.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
transforms.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
points.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
H.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {H, b};
}
std::vector<torch::Tensor> se3_build_backward_cuda(
torch::Tensor attention,
torch::Tensor transforms,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
torch::Tensor H_grad,
torch::Tensor b_grad,
int radius)
{
int batch_size = attention.size(0);
int ht = attention.size(1);
int wd = attention.size(2);
dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
torch::Tensor attention_grad = torch::zeros_like(attention);
torch::Tensor targets_grad = torch::zeros_like(targets);
torch::Tensor weights_grad = torch::zeros_like(weights);
se3_build_backward_kernel<<<grid, NUM_THREADS>>>(
attention.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
transforms.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
points.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
H_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
attention_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
targets_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
radius);
return {attention_grad, targets_grad, weights_grad};
}
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <iostream>
#define NUM_THREADS 64
#define AE_DIM 32
__device__ __forceinline__ float sigmoid(float x) {
  // numerically stable logistic: the original exp(x) / (exp(x) + 1.0) overflows
  // to inf/inf = NaN for large positive x
  return 1.0f / (1.0f + expf(-x));
}
__device__ __forceinline__ void
se3_transform_point_inplace(const float T[7], float X[3]) {
const float tx=T[0], ty=T[1], tz=T[2];
const float qx=T[3], qy=T[4], qz=T[5], qw=T[6];
float uv[3];
uv[0] = 2.0 * (qy*X[2] - qz*X[1]);
uv[1] = 2.0 * (qz*X[0] - qx*X[2]);
uv[2] = 2.0 * (qx*X[1] - qy*X[0]);
X[0] += qw*uv[0] + (qy*uv[2] - qz*uv[1]) + tx;
X[1] += qw*uv[1] + (qz*uv[0] - qx*uv[2]) + ty;
X[2] += qw*uv[2] + (qx*uv[1] - qy*uv[0]) + tz;
}
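// Applies X' = R(q) X + t via the two-cross-product expansion of a unit
// quaternion rotation: with uv = 2 (q x X), R(q) X = X + qw * uv + q x uv,
// avoiding construction of the 3x3 rotation matrix.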
__device__ __forceinline__ void
pinhole_jacobians(const float p[3], const float fx, const float fy, float Ju[6], float Jv[6], float Jz[6]) {
const float X1=p[0], Y1=p[1], Z1=p[2];
const float d = 1.0 / Z1;
const float d2 = d * d;
// x-jacobians
Ju[0] = fx * d;
Ju[1] = fx * 0.0;
Ju[2] = fx * (-X1*d2);
Ju[3] = fx * (-X1*Y1*d2);
Ju[4] = fx * (1 + X1*X1*d2);
Ju[5] = fx * (-Y1*d);
// y-jacobians
Jv[0] = fy * 0.0;
Jv[1] = fy * d;
Jv[2] = fy * (-Y1*d2);
Jv[3] = fy * -1 * (1+Y1*Y1*d2);
Jv[4] = fy * X1*Y1*d2;
Jv[5] = fy * X1*d;
// z-jacobians
Jz[0] = 0.0;
Jz[1] = 0.0;
Jz[2] = -d2;
Jz[3] = d * Y1;
Jz[4] = -d * X1;
Jz[5] = 0.0;
}
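// Rows of the 3x6 Jacobian of the measurement (u, v, z') = (fx X/Z + cx,
// fy Y/Z + cy, 1/Z) with respect to a small se3 motion of the transformed point
// p = (X, Y, Z); columns 0-2 are the translational part, columns 3-5 the
// rotational part.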
__global__ void dense_se3_forward_kernel(
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> transforms,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> embeddings,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> points,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights,
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> Hx,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> bx,
int radius)
{
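  /* Same Gauss-Newton aggregation as se3_build_forward_kernel, but the dense
     attention volume is replaced by an embedding similarity computed on the fly,
         w(i, j) = sigmoid(-||ae_i - ae_j||^2),
     and the per-pixel transform is stored as translation + unit quaternion
     (tx, ty, tz, qx, qy, qz, qw) instead of a matrix. */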
int batch_id = blockIdx.x; // batch_index
int tx = threadIdx.x;
int ix = blockIdx.y * NUM_THREADS + tx; // image_index
const int ht = transforms.size(2);
const int wd = transforms.size(3);
const int ae_dim = embeddings.size(1);
const int dim = ht * wd;
const int h1 = ix / wd;
const int w1 = ix % wd;
const float* Xdata = points[batch_id].data();
const float* rdata = targets[batch_id].data();
const float* wdata = weights[batch_id].data();
const float* ae_data = embeddings[batch_id].data();
__shared__ float fx, fy, cx, cy;
if (tx == 0) {
fx = intrinsics[batch_id][0];
fy = intrinsics[batch_id][1];
cx = intrinsics[batch_id][2];
cy = intrinsics[batch_id][3];
}
// transformation
float G[7];
float ae1[AE_DIM];
// linear system
float H[6][6], b[6];
if (ix < dim) {
G[0] = transforms[batch_id][0][h1][w1]; // tx
G[1] = transforms[batch_id][1][h1][w1]; // ty
G[2] = transforms[batch_id][2][h1][w1]; // tz
G[3] = transforms[batch_id][3][h1][w1]; // qx
G[4] = transforms[batch_id][4][h1][w1]; // qy
G[5] = transforms[batch_id][5][h1][w1]; // qz
G[6] = transforms[batch_id][6][h1][w1]; // qw
for (int ii=0; ii<ae_dim; ii++) {
ae1[ii] = embeddings[batch_id][ii][h1][w1];
}
}
for (int ii=0; ii<6; ii++) {
b[ii] = 0;
}
for (int ii=0; ii<6; ii++) {
for (int jj=0; jj<6; jj++) {
H[ii][jj] = 0;
}
}
// jacobians
float Ju[6], Jv[6], Jz[6];
__shared__ float X0[3][NUM_THREADS];
__shared__ float ae2[AE_DIM][NUM_THREADS];
__shared__ float rvec[3][NUM_THREADS];
__shared__ float wvec[3][NUM_THREADS];
__syncthreads();
for (int i=0; i<dim; i+=NUM_THREADS) {
// load in data
int jx = i + tx;
if (jx < dim) {
X0[0][tx] = Xdata[jx+0*dim];
X0[1][tx] = Xdata[jx+1*dim];
X0[2][tx] = Xdata[jx+2*dim];
rvec[0][tx] = rdata[jx+0*dim];
rvec[1][tx] = rdata[jx+1*dim];
rvec[2][tx] = rdata[jx+2*dim];
wvec[0][tx] = wdata[jx+0*dim];
wvec[1][tx] = wdata[jx+1*dim];
wvec[2][tx] = wdata[jx+2*dim];
for (int k=0; k<ae_dim; k++)
ae2[k][tx] = ae_data[jx + k*dim];
}
__syncthreads();
for (int j=0; j<NUM_THREADS; j++) {
jx = i + j;
if (ix<dim && jx<dim) {
int h2 = jx / wd;
int w2 = jx % wd;
int r = max(abs(h1-h2), abs(w1-w2));
if (r > radius)
continue;
float p[3] = { X0[0][j], X0[1][j], X0[2][j] };
se3_transform_point_inplace(G, p);
// residual vectors
const float X1=p[0], Y1=p[1], Z1=p[2];
const float u = fx * (X1 / Z1) + cx;
const float v = fy * (Y1 / Z1) + cy;
const float ru = rvec[0][j] - u;
const float rv = rvec[1][j] - v;
const float rz = rvec[2][j] - 1.0 / Z1;
// exclude pixels too close or errors too big
if (Z1 < 0.1 || abs(ru) > 128 || abs(rv) > 128)
continue;
float s=0.0;
for (int k=0; k<ae_dim; k++) {
s += (ae1[k] - ae2[k][j]) * (ae1[k] - ae2[k][j]);
}
const float w = sigmoid(-s);
const float wu = w * wvec[0][j];
const float wv = w * wvec[1][j];
const float wz = w * wvec[2][j];
pinhole_jacobians(p, fx, fy, Ju, Jv, Jz);
for (int ii=0; ii<6; ii++) {
b[ii] += wu*ru*Ju[ii] + wv*rv*Jv[ii] + wz*rz*Jz[ii];
}
for (int ii=0; ii<6; ii++) {
for (int jj=0; jj<6; jj++) {
H[ii][jj] += wu*Ju[ii]*Ju[jj] + wv*Jv[ii]*Jv[jj] + wz*Jz[ii]*Jz[jj];
}
}
}
}
__syncthreads();
}
if (ix < dim) {
for (int ii=0; ii<6; ii++) {
bx[batch_id][ii][0][h1][w1] = b[ii];
}
for (int ii=0; ii<6; ii++) {
for (int jj=0; jj<6; jj++) {
Hx[batch_id][ii][jj][h1][w1] = H[ii][jj];
}
}
}
}
__global__ void dense_se3_backward_kernel1(
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> transforms,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> embeddings,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> points,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights,
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> Hx_grad,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> bx_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> embedding_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights_grad,
int radius)
{
int batch_id = blockIdx.x; // batch_index
int tx = threadIdx.x;
int ix = blockIdx.y * NUM_THREADS + tx; // image_index
const int ht = transforms.size(2);
const int wd = transforms.size(3);
const int dim = ht * wd;
const int ae_dim = embeddings.size(1);
int h2 = ix / wd;
int w2 = ix % wd;
const float* transform_data = transforms[batch_id].data();
const float* ae_data = embeddings[batch_id].data();
const float* diffH_data = Hx_grad[batch_id].data();
const float* diffb_data = bx_grad[batch_id].data();
__shared__ float fx, fy, cx, cy;
if (tx == 0) {
fx = intrinsics[batch_id][0];
fy = intrinsics[batch_id][1];
cx = intrinsics[batch_id][2];
cy = intrinsics[batch_id][3];
}
float X0[3];
float target_u, target_v, target_z;
float wu, wv, wz;
float ae2[AE_DIM];
float diff_ae2[AE_DIM];
if (ix < dim) {
X0[0] = points[batch_id][0][h2][w2];
X0[1] = points[batch_id][1][h2][w2];
X0[2] = points[batch_id][2][h2][w2];
target_u = targets[batch_id][0][h2][w2];
target_v = targets[batch_id][1][h2][w2];
target_z = targets[batch_id][2][h2][w2];
wu = weights[batch_id][0][h2][w2];
wv = weights[batch_id][1][h2][w2];
wz = weights[batch_id][2][h2][w2];
for (int ii=0; ii<ae_dim; ii++) {
ae2[ii] = ae_data[ix + ii*dim];
diff_ae2[ii] = 0;
}
}
// jacobians
float Ju[6], Jv[6], Jz[6];
float diff_ru = 0;
float diff_rv = 0;
float diff_rz = 0;
float diff_wu = 0;
float diff_wv = 0;
float diff_wz = 0;
__shared__ float Gs[NUM_THREADS][7];
__shared__ float dH[6][6][NUM_THREADS];
__shared__ float db[6][NUM_THREADS];
__shared__ float ae1[AE_DIM][NUM_THREADS];
__syncthreads();
for (int i=0; i<dim; i+=NUM_THREADS) {
int jx = i + tx;
// read from global
if (jx < dim) {
Gs[tx][0] = transform_data[jx + 0*dim];
Gs[tx][1] = transform_data[jx + 1*dim];
Gs[tx][2] = transform_data[jx + 2*dim];
Gs[tx][3] = transform_data[jx + 3*dim];
Gs[tx][4] = transform_data[jx + 4*dim];
Gs[tx][5] = transform_data[jx + 5*dim];
Gs[tx][6] = transform_data[jx + 6*dim];
for (int ii=0; ii<ae_dim; ii++) {
ae1[ii][tx] = ae_data[jx + ii*dim];
}
for (int ii=0; ii<6; ii++) {
for (int jj=0; jj<6; jj++) {
dH[ii][jj][tx] = diffH_data[jx + (ii*6+jj)*dim];
}
}
for (int ii=0; ii<6; ii++) {
db[ii][tx] = diffb_data[jx + ii*dim];
}
}
__syncthreads();
for (int j=0; j<NUM_THREADS; j++) {
jx = i + j;
if (ix<dim && jx<dim) {
int h1 = jx / wd;
int w1 = jx % wd;
int r = max(abs(h1-h2), abs(w1-w2));
if (r > radius) continue;
float p[3] = { X0[0], X0[1], X0[2] };
se3_transform_point_inplace(&Gs[j][0], p);
// residual vectors
const float X1=p[0], Y1=p[1], Z1=p[2];
const float u = fx * (X1 / Z1) + cx;
const float v = fy * (Y1 / Z1) + cy;
const float ru = target_u - u;
const float rv = target_v - v;
const float rz = target_z - 1.0 / Z1;
float s=0.0;
for (int k=0; k<ae_dim; k++) {
s += (ae1[k][j] - ae2[k]) * (ae1[k][j] - ae2[k]);
}
float diff_w = 0.0f;
const float w = sigmoid(-s);
// exclude pixels too close or errors too big
if (Z1 < 0.1 || abs(ru) > 128 || abs(rv) > 128)
continue;
pinhole_jacobians(p, fx, fy, Ju, Jv, Jz);
for (int ii=0; ii<6; ii++) {
const float db_i = db[ii][j];
// residual gradients
diff_ru += w*wu*Ju[ii] * db_i;
diff_rv += w*wv*Jv[ii] * db_i;
diff_rz += w*wz*Jz[ii] * db_i;
// weights gradients
diff_wu += w*ru*Ju[ii] * db_i;
diff_wv += w*rv*Jv[ii] * db_i;
diff_wz += w*rz*Jz[ii] * db_i;
// embedding weight
diff_w += (wu*ru*Ju[ii] + wv*rv*Jv[ii] + wz*rz*Jz[ii]) * db_i;
for (int jj=0; jj<6; jj++) {
const float dH_ij = dH[ii][jj][j];
diff_wu += w*Ju[ii]*Ju[jj] * dH_ij;
diff_wv += w*Jv[ii]*Jv[jj] * dH_ij;
diff_wz += w*Jz[ii]*Jz[jj] * dH_ij;
diff_w += (wu*Ju[ii]*Ju[jj] + wv*Jv[ii]*Jv[jj] + wz*Jz[ii]*Jz[jj]) * dH_ij;
}
}
          float diff_s = -diff_w * w * (1.0f - w);  // d sigmoid(-s)/ds = -w*(1-w), reusing w = sigmoid(-s)
for (int k=0; k<ae_dim; k++) {
diff_ae2[k] += -2 * diff_s * (ae1[k][j] - ae2[k]);
}
}
}
__syncthreads();
}
if (ix < dim) {
targets_grad[batch_id][0][h2][w2] = diff_ru;
targets_grad[batch_id][1][h2][w2] = diff_rv;
targets_grad[batch_id][2][h2][w2] = diff_rz;
weights_grad[batch_id][0][h2][w2] = diff_wu;
weights_grad[batch_id][1][h2][w2] = diff_wv;
weights_grad[batch_id][2][h2][w2] = diff_wz;
for (int k=0; k<ae_dim; k++)
embedding_grad[batch_id][k][h2][w2] += diff_ae2[k];
}
}
__global__ void dense_se3_backward_kernel2(
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> transforms,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> embeddings,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> points,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets,
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights,
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> Hx_grad,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> bx_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> embedding_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> targets_grad,
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> weights_grad,
int radius) {
int batch_id = blockIdx.x; // batch_index
int tx = threadIdx.x;
int ix = blockIdx.y * NUM_THREADS + tx; // image_index
const int ht = transforms.size(2);
const int wd = transforms.size(3);
const int ae_dim = embeddings.size(1);
const int dim = ht * wd;
const int h1 = ix / wd;
const int w1 = ix % wd;
const float* transform_data = transforms[batch_id].data();
const float* Xdata = points[batch_id].data();
const float* rdata = targets[batch_id].data();
const float* wdata = weights[batch_id].data();
const float* ae_data = embeddings[batch_id].data();
__shared__ float fx, fy, cx, cy;
if (tx == 0) {
fx = intrinsics[batch_id][0];
fy = intrinsics[batch_id][1];
cx = intrinsics[batch_id][2];
cy = intrinsics[batch_id][3];
}
// transformation
float G[7];
float ae1[AE_DIM];
float diff_ae1[AE_DIM];
float db[6], dH[6][6];
if (ix < dim) {
G[0] = transform_data[ix + 0*dim]; // tx
G[1] = transform_data[ix + 1*dim]; // ty
G[2] = transform_data[ix + 2*dim]; // tz
G[3] = transform_data[ix + 3*dim]; // qx
G[4] = transform_data[ix + 4*dim]; // qy
G[5] = transform_data[ix + 5*dim]; // qz
G[6] = transform_data[ix + 6*dim]; // qw
for (int ii=0; ii<ae_dim; ii++) {
ae1[ii] = embeddings[batch_id][ii][h1][w1];
diff_ae1[ii] = 0;
}
for (int ii=0; ii<6; ii++) {
db[ii] = bx_grad[batch_id][ii][0][h1][w1];
}
for (int ii=0; ii<6; ii++) {
for (int jj=0; jj<6; jj++) {
dH[ii][jj] = Hx_grad[batch_id][ii][jj][h1][w1];
}
}
}
// jacobians
float Ju[6], Jv[6], Jz[6];
__shared__ float X0[3][NUM_THREADS];
__shared__ float ae2[AE_DIM][NUM_THREADS];
__shared__ float rvec[3][NUM_THREADS];
__shared__ float wvec[3][NUM_THREADS];
__syncthreads();
for (int i=0; i<dim; i+=NUM_THREADS) {
// load in data
int jx = i + tx;
if (jx < dim) {
X0[0][tx] = Xdata[jx+0*dim];
X0[1][tx] = Xdata[jx+1*dim];
X0[2][tx] = Xdata[jx+2*dim];
rvec[0][tx] = rdata[jx+0*dim];
rvec[1][tx] = rdata[jx+1*dim];
rvec[2][tx] = rdata[jx+2*dim];
wvec[0][tx] = wdata[jx+0*dim];
wvec[1][tx] = wdata[jx+1*dim];
wvec[2][tx] = wdata[jx+2*dim];
for (int k=0; k<ae_dim; k++)
ae2[k][tx] = ae_data[jx + k*dim];
}
__syncthreads();
for (int j=0; j<NUM_THREADS; j++) {
jx = i + j;
if (ix<dim && jx<dim) {
int h2 = jx / wd;
int w2 = jx % wd;
int r = max(abs(h1-h2), abs(w1-w2));
if (r > radius) continue;
float p[3] = { X0[0][j], X0[1][j], X0[2][j] };
se3_transform_point_inplace(G, p);
// residual vectors
const float X1=p[0], Y1=p[1], Z1=p[2];
const float u = fx * (X1 / Z1) + cx;
const float v = fy * (Y1 / Z1) + cy;
const float ru = rvec[0][j] - u;
const float rv = rvec[1][j] - v;
const float rz = rvec[2][j] - 1.0 / Z1;
float s=0.0;
for (int k=0; k<ae_dim; k++) {
s += (ae1[k] - ae2[k][j]) * (ae1[k] - ae2[k][j]);
}
const float w = sigmoid(-s);
float diff_w = 0;
const float wu = wvec[0][j];
const float wv = wvec[1][j];
const float wz = wvec[2][j];
// exclude pixels too close or errors too big
if (Z1 < 0.1 || abs(ru) > 128 || abs(rv) > 128)
continue;
pinhole_jacobians(p, fx, fy, Ju, Jv, Jz);
for (int ii=0; ii<6; ii++) {
diff_w += (wu*ru*Ju[ii] + wv*rv*Jv[ii] + wz*rz*Jz[ii]) * db[ii];
for (int jj=0; jj<6; jj++) {
diff_w += (wu*Ju[ii]*Ju[jj] + wv*Jv[ii]*Jv[jj] + wz*Jz[ii]*Jz[jj]) * dH[ii][jj];
}
}
          float diff_s = -diff_w * w * (1.0f - w);  // d sigmoid(-s)/ds = -w*(1-w), reusing w = sigmoid(-s)
for (int k=0; k<ae_dim; k++) {
diff_ae1[k] += 2 * diff_s * (ae1[k] - ae2[k][j]);
}
}
}
__syncthreads();
}
if (ix < dim) {
for (int k=0; k<ae_dim; k++)
embedding_grad[batch_id][k][h1][w1] += diff_ae1[k];
}
}
std::vector<torch::Tensor> dense_se3_forward_cuda(
torch::Tensor transforms,
torch::Tensor embeddings,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
int radius)
{
int batch_size = transforms.size(0);
int ht = transforms.size(2);
int wd = transforms.size(3);
dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
auto opts = targets.options();
torch::Tensor H = torch::zeros({batch_size, 6, 6, ht, wd}, opts);
torch::Tensor b = torch::zeros({batch_size, 6, 1, ht, wd}, opts);
dense_se3_forward_kernel<<<grid, NUM_THREADS>>>(
transforms.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
embeddings.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
points.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
H.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {H, b};
}
std::vector<torch::Tensor> dense_se3_backward_cuda(
torch::Tensor transforms,
torch::Tensor embeddings,
torch::Tensor points,
torch::Tensor targets,
torch::Tensor weights,
torch::Tensor intrinsics,
torch::Tensor H_grad,
torch::Tensor b_grad,
int radius)
{
int batch_size = transforms.size(0);
int ht = transforms.size(2);
int wd = transforms.size(3);
dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
torch::Tensor embedding_grad = torch::zeros_like(embeddings);
torch::Tensor targets_grad = torch::zeros_like(targets);
torch::Tensor weights_grad = torch::zeros_like(weights);
// backward pass split into two kernels to avoid atomics
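  // kernel1 indexes pixels as the second endpoint (h2, w2) and writes the target,
  // weight and ae2-side embedding gradients there; kernel2 re-indexes them as the
  // first endpoint (h1, w1) and accumulates the ae1-side embedding gradients.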
dense_se3_backward_kernel1<<<grid, NUM_THREADS>>>(
transforms.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
embeddings.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
points.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
H_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
embedding_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
radius);
dense_se3_backward_kernel2<<<grid, NUM_THREADS>>>(
transforms.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
embeddings.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
points.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
H_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
embedding_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
targets_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
weights_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
radius);
return {embedding_grad, targets_grad, weights_grad};
}
#include <torch/extension.h>
#include <vector>
#define NUM_THREADS 64
#define EPS 1e-8
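// llt computes the lower-triangular Cholesky factor L with A = L L^T using the
// Cholesky-Banachiewicz recurrence; llt_solve then solves L L^T x = b in place by
// forward substitution followed by back substitution.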
template <int N>
__device__ __forceinline__ void llt(const float A[N][N], float L[N][N])
{
for (int i=0; i<N; i++) {
for (int j=0; j<N; j++) {
L[i][j] = 0;
}
}
float s;
for (int i=0; i<N; i++) {
for (int j=0; j<(i+1); j++) {
s = 0.0;
for (int k=0; k<j; k++)
s += L[i][k] * L[j][k];
      if (i==j) {
        // clamp the accumulated sum so the sqrtf argument stays at least EPS;
        // the previous "A[i][i] + EPS" clamp made the argument negative and
        // produced NaNs for near-singular H
        s = s > A[i][i] - EPS ? A[i][i] - EPS : s;
        L[i][j] = sqrtf(A[i][i]-s);
      }
else
L[i][j] = (A[i][j] - s) / L[j][j];
}
}
}
template <int N>
__device__ __forceinline__ void llt_solve(const float L[N][N], float x[N])
{
float s;
for (int i=0; i<N; i++) {
s = 0.0;
for (int j=0; j<i; j++)
s += L[i][j] * x[j];
x[i] = (x[i] - s) / L[i][i];
}
for (int i=N-1; i>=0; i--) {
s = 0.0;
for (int j=i+1; j<N; j++)
s += L[j][i] * x[j];
x[i] = (x[i] - s) / L[i][i];
}
}
__global__ void cholesky_solve6x6_forward_kernel(
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> H_tensor,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> b_tensor,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> x_tensor) {
    /* Inputs: H [batch,6,6,ht,wd], b [batch,6,1,ht,wd]
       Outputs: x [batch,6,1,ht,wd]; solves H x = b per pixel
    */
int batch_id = blockIdx.x;
const int dim = H_tensor.size(3) * H_tensor.size(4);
int m = blockIdx.y * NUM_THREADS + threadIdx.x;
const float* H_ptr = H_tensor[batch_id].data();
const float* b_ptr = b_tensor[batch_id].data();
float* x_ptr = x_tensor[batch_id].data();
if (m < dim) {
float H[6][6], L[6][6], x[6];
for (int i=0; i<6; i++) {
for (int j=0; j<6; j++) {
H[i][j] = H_ptr[m + (6*i+j)*dim];
}
}
for (int i=0; i<6; i++) {
x[i] = b_ptr[m + i*dim];
}
llt<6>(H, L);
llt_solve<6>(L, x);
for (int i=0; i<6; i++) {
x_ptr[m + i*dim] = x[i];
}
}
}
__global__ void cholesky_solve6x6_backward_kernel(
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> H_tensor,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> b_tensor,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> dx_tensor,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> dH_tensor,
torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> db_tensor) {
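  /* Backward pass of x = H^{-1} b: with dz = H^{-1} dx (H symmetric positive
     definite), the gradients are db = dz and dH = -dz x^T, which is exactly what
     is written out below. */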
int batch_id = blockIdx.x;
const int dim = H_tensor.size(3) * H_tensor.size(4);
int m = blockIdx.y * NUM_THREADS + threadIdx.x;
const float* H_ptr = H_tensor[batch_id].data();
const float* b_ptr = b_tensor[batch_id].data();
const float* dx_ptr = dx_tensor[batch_id].data();
float* dH_ptr = dH_tensor[batch_id].data();
float* db_ptr = db_tensor[batch_id].data();
if (m < dim) {
float H[6][6], L[6][6], x[6], dz[6];
for (int i=0; i<6; i++) {
for (int j=0; j<6; j++) {
H[i][j] = H_ptr[m + (6*i+j)*dim];
}
}
for (int i=0; i<6; i++) {
x[i] = b_ptr[m + i*dim];
}
for (int i=0; i<6; i++) {
dz[i] = dx_ptr[m + i*dim];
}
// cholesky factorization
llt<6>(H, L);
llt_solve<6>(L, x);
llt_solve<6>(L, dz);
for (int i=0; i<6; i++) {
for (int j=0; j<6; j++) {
dH_ptr[m + (6*i+j)*dim] = -dz[i] * x[j];
}
}
for (int i=0; i<6; i++) {
db_ptr[m + i*dim] = dz[i];
}
}
}
std::vector<torch::Tensor> cholesky_solve6x6_forward_cuda(torch::Tensor H, torch::Tensor b) {
const int batch_size = H.size(0);
const int ht = H.size(3);
const int wd = H.size(4);
torch::Tensor x = torch::zeros_like(b);
dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
cholesky_solve6x6_forward_kernel<<<grid, NUM_THREADS>>>(
H.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
x.packed_accessor32<float,5,torch::RestrictPtrTraits>());
return {x};
}
std::vector<torch::Tensor> cholesky_solve6x6_backward_cuda(torch::Tensor H, torch::Tensor b, torch::Tensor dx) {
const int batch_size = H.size(0);
const int ht = H.size(3);
const int wd = H.size(4);
torch::Tensor dH = torch::zeros_like(H);
torch::Tensor db = torch::zeros_like(b);
dim3 grid = dim3(batch_size, (ht*wd + NUM_THREADS-1) / NUM_THREADS);
cholesky_solve6x6_backward_kernel<<<grid, NUM_THREADS>>>(
H.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
b.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
dx.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
dH.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
db.packed_accessor32<float,5,torch::RestrictPtrTraits>());
return {dH, db};
}
import torch
from torch.types import _TensorOrTensors
from torch._six import container_abcs, istuple
import torch.testing
from torch.overrides import is_tensor_like
from itertools import product
import warnings
from typing import Callable, Union, Optional, Iterable, List
def zero_gradients(x):
if isinstance(x, torch.Tensor):
if x.grad is not None:
x.grad.detach_()
x.grad.zero_()
elif isinstance(x, container_abcs.Iterable):
for elem in x:
zero_gradients(elem)
def make_jacobian(input, num_out):
if is_tensor_like(input):
if not input.is_floating_point() and not input.is_complex():
return None
if not input.requires_grad:
return None
return input.new_zeros((input.nelement(), num_out), dtype=input.dtype, layout=torch.strided)
elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
jacobians = list(filter(
lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
if not jacobians:
return None
return type(input)(jacobians) # type: ignore
else:
return None
def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]:
if is_tensor_like(x):
# mypy doesn't narrow type of `x` to torch.Tensor
if x.requires_grad or not only_requiring_grad: # type: ignore
yield x # type: ignore
elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
for elem in x:
for result in iter_tensors(elem, only_requiring_grad):
yield result
def get_numerical_jacobian(fn, input, target=None, eps=1e-3, grad_out=1.0):
"""
input: input to `fn`
target: the Tensors wrt whom Jacobians are calculated (default=`input`)
grad_out: grad output value used to calculate gradients.
    Note that `target` may not even be part of `input` to `fn`, so please be
    **very careful** here not to clone `target`.
"""
if target is None:
target = input
output_size = fn(input).numel()
jacobian = make_jacobian(target, output_size)
# It's much easier to iterate over flattened lists of tensors.
# These are reference to the same objects in jacobian, so any changes
# will be reflected in it as well.
x_tensors = iter_tensors(target, True)
j_tensors = iter_tensors(jacobian)
def update_jacobians(x, idx, d, d_idx, is_mkldnn=False):
# compute_jacobian only works for pure real
# or pure imaginary delta
def compute_gradient(delta):
# we currently assume that the norm of delta equals eps
assert(delta == eps or delta == (eps * 1j))
def fn_out():
if not is_mkldnn:
# x is a view into input and so this works
return fn(input).clone()
else:
# convert the dense tensor back to have mkldnn layout
return fn([x.to_mkldnn()])
orig = x[idx].item()
x[idx] = orig - delta
outa = fn_out()
x[idx] = orig + delta
outb = fn_out()
x[idx] = orig
r = (outb - outa) / (2 * eps)
return r.detach().reshape(-1)
# for details on the algorithm used here, refer:
# Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
# s = fn(z) where z = x for real valued input
# and z = x + yj for complex valued input
ds_dx = compute_gradient(eps)
if x.is_complex(): # C -> C, C -> R
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
elif ds_dx.is_complex(): # R -> C
# w_d = conj_w_d = 0.5 * ds_dx
# dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()]
# = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()]
# = 0.5 * 2 * real(grad_out.conj() * ds_dx)
# = real(grad_out.conj() * ds_dx)
d[d_idx] = torch.real(grad_out.conjugate() * ds_dx)
else: # R -> R
d[d_idx] = ds_dx * grad_out
# TODO: compare structure
for x_tensor, d_tensor in zip(x_tensors, j_tensors):
if x_tensor.is_sparse:
def get_stride(size):
dim = len(size)
tmp = 1
stride = [0] * dim
for i in reversed(range(dim)):
stride[i] = tmp
tmp *= size[i]
return stride
x_nnz = x_tensor._nnz()
x_size = list(x_tensor.size())
x_indices = x_tensor._indices().t()
x_values = x_tensor._values()
x_stride = get_stride(x_size)
# Use .data here to get around the version check
x_values = x_values.data
for i in range(x_nnz):
x_value = x_values[i]
for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
indices = x_indices[i].tolist() + list(x_idx)
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
update_jacobians(x_value, x_idx, d_tensor, d_idx)
elif x_tensor.layout == torch._mkldnn: # type: ignore
# Use .data here to get around the version check
x_tensor = x_tensor.data
            if len(input) != 1:
                raise ValueError('gradcheck currently only supports functions with 1 input, '
                                 'but got: {}'.format(len(input)))
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
# this is really inefficient, but without indexing implemented, there's
# not really a better way than converting back and forth
x_tensor_dense = x_tensor.to_dense()
update_jacobians(x_tensor_dense, x_idx, d_tensor, d_idx, is_mkldnn=True)
else:
# Use .data here to get around the version check
x_tensor = x_tensor.data
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
update_jacobians(x_tensor, x_idx, d_tensor, d_idx)
return jacobian
def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
# it is easier to call to_dense() on the sparse output than
# to modify analytical jacobian
if output.is_sparse:
raise ValueError('Sparse output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
if output.layout == torch._mkldnn: # type: ignore
raise ValueError('MKLDNN output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
diff_input_list = list(iter_tensors(input, True))
jacobian = make_jacobian(input, output.numel())
jacobian_reentrant = make_jacobian(input, output.numel())
grad_output = torch.zeros_like(output, memory_format=torch.legacy_contiguous_format)
flat_grad_output = grad_output.view(-1)
reentrant = True
correct_grad_sizes = True
correct_grad_types = True
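    # Build the Jacobian one column at a time: set a one-hot grad_output, take a
    # vector-Jacobian product via torch.autograd.grad, and store the result as
    # column i of each per-input Jacobian (done twice to detect nondeterminism).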
for i in range(flat_grad_output.numel()):
flat_grad_output.zero_()
flat_grad_output[i] = grad_out
for jacobian_c in (jacobian, jacobian_reentrant):
grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
retain_graph=True, allow_unused=True)
for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list):
if d_x is not None and d_x.size() != x.size():
correct_grad_sizes = False
elif d_x is not None and d_x.dtype != x.dtype:
correct_grad_types = False
elif jacobian_x.numel() != 0:
if d_x is None:
jacobian_x[:, i].zero_()
else:
d_x_dense = d_x.to_dense() if not d_x.layout == torch.strided else d_x
assert jacobian_x[:, i].numel() == d_x_dense.numel()
jacobian_x[:, i] = d_x_dense.contiguous().view(-1)
for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
if jacobian_x.numel() != 0 and (jacobian_x - jacobian_reentrant_x).abs().max() > nondet_tol:
reentrant = False
return jacobian, reentrant, correct_grad_sizes, correct_grad_types
def _as_tuple(x):
if istuple(x):
return x
elif isinstance(x, list):
return tuple(x)
else:
return x,
def _differentiable_outputs(x):
return tuple(o for o in _as_tuple(x) if o.requires_grad)
# Note [VarArg of Tensors]
# ~~~~~~~~~~~~~~~~~~~~~~~~
# 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
# For now, we permit any input.
def gradcheck(
func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors]
inputs: _TensorOrTensors,
eps: float = 1e-6,
atol: float = 1e-5,
rtol: float = 1e-3,
raise_exception: bool = True,
check_sparse_nnz: bool = False,
nondet_tol: float = 0.0,
check_undefined_grad: bool = True,
check_grad_dtypes: bool = False
) -> bool:
r"""Check gradients computed via small finite differences against analytical
gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type
and with ``requires_grad=True``.
The check between numerical and analytical gradients uses :func:`~torch.allclose`.
For complex functions, no notion of Jacobian exists. Gradcheck verifies if the numerical and
analytical values of Wirtinger and Conjugate Wirtinger derivative are consistent. The gradient
computation is done under the assumption that the overall function has a real valued output.
For functions with complex output, gradcheck compares the numerical and analytical gradients
for two values of :attr:`grad_output`: 1 and 1j. For more details, check out
:ref:`complex_autograd-doc`.
.. note::
The default values are designed for :attr:`input` of double precision.
This check will likely fail if :attr:`input` is of less precision, e.g.,
``FloatTensor``.
.. warning::
If any checked tensor in :attr:`input` has overlapping memory, i.e.,
different indices pointing to the same memory address (e.g., from
:func:`torch.expand`), this check will likely fail because the numerical
gradients computed by point perturbation at such indices will change
values at all other indices that share the same memory address.
Args:
func (function): a Python function that takes Tensor inputs and returns
a Tensor or a tuple of Tensors
inputs (tuple of Tensor or Tensor): inputs to the function
eps (float, optional): perturbation for finite differences
atol (float, optional): absolute tolerance
rtol (float, optional): relative tolerance
raise_exception (bool, optional): indicating whether to raise an exception if
the check fails. The exception gives more information about the
exact nature of the failure. This is helpful when debugging gradchecks.
check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
and for any SparseTensor at input, gradcheck will perform check at nnz positions only.
nondet_tol (float, optional): tolerance for non-determinism. When running
identical inputs through the differentiation, the results must either match
exactly (default, 0.0) or be within this tolerance.
        check_undefined_grad (bool, optional): if True, check if undefined output grads
are supported and treated as zeros, for ``Tensor`` outputs.
Returns:
True if all differences satisfy allclose condition
"""
def fail_test(msg):
if raise_exception:
raise RuntimeError(msg)
return False
tupled_inputs = _as_tuple(inputs)
if not check_sparse_nnz and any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)):
return fail_test('gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False.')
# Make sure that gradients are saved for at least one input
any_input_requiring_grad = False
for idx, inp in enumerate(tupled_inputs):
if is_tensor_like(inp) and inp.requires_grad:
if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
warnings.warn(
f'Input #{idx} requires gradient and '
'is not a double precision floating point or complex. '
'This check will likely fail if all the inputs are '
'not of double precision floating point or complex. ')
content = inp._values() if inp.is_sparse else inp
# TODO: To cover more problematic cases, replace stride = 0 check with
# "any overlap in memory" once we have a proper function to check it.
if content.layout is not torch._mkldnn: # type: ignore
                if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
                    raise RuntimeError(
                        'The {}th input has a dimension with stride 0. gradcheck only '
                        'supports inputs that are non-overlapping to be able to '
                        'compute the numerical gradients correctly. You should call '
                        '.contiguous on the input before passing it to gradcheck.'.format(idx))
any_input_requiring_grad = True
inp.retain_grad()
if not any_input_requiring_grad:
        raise ValueError(
            'gradcheck expects at least one input tensor to require gradient, '
            'but none of them has requires_grad=True.')
func_out = func(*tupled_inputs)
output = _differentiable_outputs(func_out)
if not output:
for i, o in enumerate(func_out):
def fn(input):
return _as_tuple(func(*input))[i]
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
for n in numerical:
if torch.ne(n, 0).sum() > 0:
return fail_test('Numerical gradient for function expected to be zero')
return True
for i, o in enumerate(output):
if not o.requires_grad:
continue
def fn(input):
return _as_tuple(func(*input))[i]
analytical, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian(tupled_inputs,
o,
nondet_tol=nondet_tol)
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
        # NOTE: early return -- this vendored copy hands the analytical and numerical
        # Jacobians back to the caller for inspection; the remaining checks below
        # (complex outputs, reentrancy, grad_output scaling) are never reached.
        return analytical, numerical
out_is_complex = o.is_complex()
if out_is_complex:
# analytical vjp with grad_out = 1.0j
analytical_with_imag_grad_out, reentrant_with_imag_grad_out, \
correct_grad_sizes_with_imag_grad_out, correct_grad_types_with_imag_grad_out \
= get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol, grad_out=1j)
numerical_with_imag_grad_out = get_numerical_jacobian(fn, tupled_inputs, eps=eps, grad_out=1j)
if not correct_grad_types and check_grad_dtypes:
return fail_test('Gradient has dtype mismatch')
if out_is_complex and not correct_grad_types_with_imag_grad_out and check_grad_dtypes:
return fail_test('Gradient (calculated using complex valued grad output) has dtype mismatch')
if not correct_grad_sizes:
return fail_test('Analytical gradient has incorrect size')
if out_is_complex and not correct_grad_sizes_with_imag_grad_out:
return fail_test('Analytical gradient (calculated using complex valued grad output) has incorrect size')
def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
if not torch.allclose(a, n, rtol, atol):
return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
inp_tensors = iter_tensors(tupled_inputs, True)
for j, (a, n, inp) in enumerate(zip(analytical, numerical, inp_tensors)):
if a.numel() != 0 or n.numel() != 0:
if o.is_complex():
# C -> C, R -> C
a_with_imag_grad_out = analytical_with_imag_grad_out[j]
n_with_imag_grad_out = numerical_with_imag_grad_out[j]
checkIfNumericalAnalyticAreClose(a_with_imag_grad_out, n_with_imag_grad_out, j,
"Gradients failed to compare equal for grad output = 1j. ")
if inp.is_complex():
# C -> R, C -> C
checkIfNumericalAnalyticAreClose(a, n, j,
"Gradients failed to compare equal for grad output = 1. ")
else:
# R -> R, R -> C
checkIfNumericalAnalyticAreClose(a, n, j)
    def not_reentrant_error(error_str=''):
        error_msg = "Backward" + error_str + " is not reentrant, i.e., running backward with same " \
            "input and grad_output multiple times gives different values, " \
            "although analytical gradient matches numerical gradient. " \
            "The tolerance for nondeterminism was {}.".format(nondet_tol)
        return fail_test(error_msg)
if not reentrant:
return not_reentrant_error()
if out_is_complex and not reentrant_with_imag_grad_out:
return not_reentrant_error(' (calculated using complex valued grad output)')
# check if the backward multiplies by grad_output
output = _differentiable_outputs(func(*tupled_inputs))
if any([o.requires_grad for o in output]):
diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True))
if not diff_input_list:
raise RuntimeError("no Tensors requiring grad found in input")
grads_input = torch.autograd.grad(output, diff_input_list,
[torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output],
allow_unused=True)
for gi, di in zip(grads_input, diff_input_list):
if gi is None:
continue
if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
if gi.layout != di.layout:
return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')')
if gi.layout == torch.sparse_coo:
if gi.sparse_dim() != di.sparse_dim():
return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
if gi.dense_dim() != di.dense_dim():
return fail_test('grad is sparse tensor, but has incorrect dense_dim')
gi = gi.to_dense()
di = di.to_dense()
if not gi.eq(0).all():
return fail_test('backward not multiplied by grad_output')
if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse:
return fail_test("grad is incorrect type")
if gi.size() != di.size():
return fail_test('grad is incorrect size')
if check_undefined_grad:
def warn_bc_breaking():
warnings.warn((
'Backwards compatibility: New undefined gradient support checking '
'feature is enabled by default, but it may break existing callers '
'of this function. If this is true for you, you can call this '
'function with "check_undefined_grad=False" to disable the feature'))
def check_undefined_grad_support(output_to_check):
grads_output = [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output_to_check]
try:
grads_input = torch.autograd.grad(output_to_check,
diff_input_list,
grads_output,
allow_unused=True)
except RuntimeError:
warn_bc_breaking()
return fail_test((
'Expected backward function to handle undefined output grads. '
'Please look at "Notes about undefined output gradients" in '
'"tools/autograd/derivatives.yaml"'))
for gi, i in zip(grads_input, diff_input_list):
if (gi is not None) and (not gi.eq(0).all()):
warn_bc_breaking()
return fail_test((
'Expected all input grads to be undefined or zero when all output grads are undefined '
'or zero. Please look at "Notes about undefined output gradients" in '
'"tools/autograd/derivatives.yaml"'))
return True
# All backward functions must work properly if all output grads are undefined
outputs_to_check = [[
torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs))
# This check filters out Tensor-likes that aren't instances of Tensor.
if isinstance(o, torch.Tensor)
]]
# If there are multiple output grads, we should be able to undef one at a time without error
if len(outputs_to_check[0]) > 1:
for undef_grad_idx in range(len(output)):
output_to_check = _differentiable_outputs(func(*tupled_inputs))
outputs_to_check.append([
torch._C._functions.UndefinedGrad()(o) if idx == undef_grad_idx else o
for idx, o in enumerate(output_to_check)])
for output_to_check in outputs_to_check:
if not check_undefined_grad_support(output_to_check):
return False
return True
def gradgradcheck(
func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors]
inputs: _TensorOrTensors,
grad_outputs: Optional[_TensorOrTensors] = None,
eps: float = 1e-6,
atol: float = 1e-5,
rtol: float = 1e-3,
gen_non_contig_grad_outputs: bool = False,
raise_exception: bool = True,
nondet_tol: float = 0.0,
check_undefined_grad: bool = True,
check_grad_dtypes: bool = False
) -> bool:
r"""Check gradients of gradients computed via small finite differences
against analytical gradients w.r.t. tensors in :attr:`inputs` and
:attr:`grad_outputs` that are of floating point or complex type and with
``requires_grad=True``.
    This function checks that backpropagating through the gradients computed
    with the given :attr:`grad_outputs` is correct.
The check between numerical and analytical gradients uses :func:`~torch.allclose`.
.. note::
The default values are designed for :attr:`input` and
:attr:`grad_outputs` of double precision. This check will likely fail if
they are of less precision, e.g., ``FloatTensor``.
.. warning::
If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
overlapping memory, i.e., different indices pointing to the same memory
address (e.g., from :func:`torch.expand`), this check will likely fail
because the numerical gradients computed by point perturbation at such
indices will change values at all other indices that share the same
memory address.
Args:
func (function): a Python function that takes Tensor inputs and returns
a Tensor or a tuple of Tensors
inputs (tuple of Tensor or Tensor): inputs to the function
grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
respect to the function's outputs.
eps (float, optional): perturbation for finite differences
atol (float, optional): absolute tolerance
rtol (float, optional): relative tolerance
gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
randomly generated gradient outputs are made to be noncontiguous
raise_exception (bool, optional): indicating whether to raise an exception if
the check fails. The exception gives more information about the
exact nature of the failure. This is helpful when debugging gradchecks.
nondet_tol (float, optional): tolerance for non-determinism. When running
identical inputs through the differentiation, the results must either match
exactly (default, 0.0) or be within this tolerance. Note that a small amount
of nondeterminism in the gradient will lead to larger inaccuracies in
the second derivative.
        check_undefined_grad (bool, optional): if True, check if undefined output grads
are supported and treated as zeros
Returns:
True if all differences satisfy allclose condition
"""
tupled_inputs = _as_tuple(inputs)
if grad_outputs is None:
# If grad_outputs is not specified, create random Tensors of the same
# shape, type, and device as the outputs
def randn_like(x):
y = torch.testing.randn_like(
x if (x.is_floating_point() or x.is_complex()) else x.double(), memory_format=torch.legacy_contiguous_format)
if gen_non_contig_grad_outputs:
y = torch.testing.make_non_contiguous(y)
return y.requires_grad_()
outputs = _as_tuple(func(*tupled_inputs))
tupled_grad_outputs = tuple(randn_like(x) for x in outputs)
else:
tupled_grad_outputs = _as_tuple(grad_outputs)
num_outputs = len(tupled_grad_outputs)
def new_func(*args):
input_args = args[:-num_outputs]
grad_outputs = args[-num_outputs:]
outputs = _differentiable_outputs(func(*input_args))
input_args = tuple(x for x in input_args if isinstance(x, torch.Tensor) and x.requires_grad)
grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True)
return grad_inputs
return gradcheck(new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception,
nondet_tol=nondet_tol, check_undefined_grad=check_undefined_grad,
check_grad_dtypes=check_grad_dtypes)
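# Usage sketch (illustrative only): because of the early "return analytical, numerical"
# above, this vendored gradcheck hands back the two Jacobians of the first
# differentiable output for inspection instead of running the remaining checks.
#
#   import torch
#   x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
#   analytical, numerical = gradcheck(torch.sin, (x,), eps=1e-6)
#   print((analytical[0] - numerical[0]).abs().max())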
import lietorch_backends
import torch
import torch.nn.functional as F
class GroupOp(torch.autograd.Function):
""" group operation base class """
@classmethod
def forward(cls, ctx, group_id, *inputs):
ctx.group_id = group_id
ctx.save_for_backward(*inputs)
out = cls.forward_op(ctx.group_id, *inputs)
return out
@classmethod
def backward(cls, ctx, grad):
error_str = "Backward operation not implemented for {}".format(cls)
assert cls.backward_op is not None, error_str
inputs = ctx.saved_tensors
grad = grad.contiguous()
grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
return (None, ) + tuple(grad_inputs)
class Exp(GroupOp):
""" exponential map """
forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward
class Log(GroupOp):
""" logarithm map """
forward_op, backward_op = lietorch_backends.logm, lietorch_backends.logm_backward
class Inv(GroupOp):
""" group inverse """
forward_op, backward_op = lietorch_backends.inv, lietorch_backends.inv_backward
class Mul(GroupOp):
""" group multiplication """
forward_op, backward_op = lietorch_backends.mul, lietorch_backends.mul_backward
class Adj(GroupOp):
""" adjoint operator """
forward_op, backward_op = lietorch_backends.adj, lietorch_backends.adj_backward
class AdjT(GroupOp):
""" adjoint operator """
forward_op, backward_op = lietorch_backends.adjT, lietorch_backends.adjT_backward
class Act3(GroupOp):
""" action on point """
forward_op, backward_op = lietorch_backends.act, lietorch_backends.act_backward
class Act4(GroupOp):
""" action on point """
forward_op, backward_op = lietorch_backends.act4, lietorch_backends.act4_backward
class Jinv(GroupOp):
""" adjoint operator """
forward_op, backward_op = lietorch_backends.Jinv, None
class ToMatrix(GroupOp):
""" convert to matrix representation """
forward_op, backward_op = lietorch_backends.as_matrix, None
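# A brief invocation sketch, assuming a compiled lietorch_backends extension:
# each op is called through torch.autograd.Function.apply with the integer
# group_id first (SE3 uses id 3 in groups.py below) followed by the data, so
# gradients flow through the matching *_backward kernel:
#
#   X_data = Exp.apply(3, xi)        # xi: batch of se3 tangent vectors
#   xi_rec = Log.apply(3, X_data)    # round trip back to the tangent space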
class ExtractTranslation(torch.autograd.Function):
""" group operation base class """
@staticmethod
def forward(ctx, data):
ctx.save_for_backward(data)
return data[...,:3]
    @staticmethod
    def backward(ctx, dt):
        data, = ctx.saved_tensors
        t = data[...,:3]
        # the gradient is returned in the se3 tangent space (tau, phi) and
        # padded to the embedding size: a left perturbation gives
        # t' ~ t + tau + phi x t, so dL/dtau = dt and dL/dphi = t x dt
        diff_tau_phi = torch.zeros_like(data)
        diff_tau_phi[...,0:3] = dt
        diff_tau_phi[...,3:6] = torch.cross(t, dt)
        return diff_tau_phi
import torch
import numpy as np
# group operations implemented in cuda
from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ExtractTranslation
from .broadcasting import broadcast_inputs
class LieGroupParameter(torch.Tensor):
""" Wrapper class for LieGroup """
from torch._C import _disabled_torch_function_impl
__torch_function__ = _disabled_torch_function_impl
def __new__(cls, group, requires_grad=True):
data = torch.zeros(group.tangent_shape,
device=group.data.device,
dtype=group.data.dtype,
requires_grad=True)
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __init__(self, group):
self.group = group
def retr(self):
return self.group.retr(self)
def log(self):
return self.retr().log()
def inv(self):
return self.retr().inv()
def adj(self, a):
return self.retr().adj(a)
def __mul__(self, other):
if isinstance(other, LieGroupParameter):
return self.retr() * other.retr()
else:
return self.retr() * other
def add_(self, update, alpha):
self.group = self.group.exp(alpha*update) * self.group
def __getitem__(self, index):
return self.retr().__getitem__(index)
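# A minimal optimization sketch (illustrative only), using the SE3 group
# defined later in this module; the objective is a stand-in. The parameter
# holds a zero tangent vector that retr() maps back onto the group:
#
#   G = SE3.Random(10)
#   dG = LieGroupParameter(G)
#   optim = torch.optim.SGD([dG], lr=1e-2)
#   loss = dG.log().pow(2).sum()   # placeholder objective
#   loss.backward()
#   optim.step()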
class LieGroup:
""" Base class for Lie Group """
def __init__(self, data):
self.data = data
def __repr__(self):
return "{}: size={}, device={}, dtype={}".format(
self.group_name, self.shape, self.device, self.dtype)
@property
def shape(self):
return self.data.shape[:-1]
@property
def device(self):
return self.data.device
@property
def dtype(self):
return self.data.dtype
@property
def tangent_shape(self):
return self.data.shape[:-1] + (self.manifold_dim,)
@classmethod
def Identity(cls, *batch_shape, **kwargs):
""" Construct identity element with batch shape """
if isinstance(batch_shape[0], tuple):
batch_shape = batch_shape[0]
elif isinstance(batch_shape[0], list):
batch_shape = tuple(batch_shape[0])
numel = np.prod(batch_shape)
data = cls.id_elem.reshape(1,-1)
if 'device' in kwargs:
data = data.to(kwargs['device'])
if 'dtype' in kwargs:
data = data.type(kwargs['dtype'])
data = data.repeat(numel, 1)
return cls(data).view(batch_shape)
@classmethod
def IdentityLike(cls, G):
return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)
@classmethod
def Random(cls, *batch_shape, sigma=1.0, **kwargs):
""" Construct random element with batch_shape by random sampling in tangent space"""
if isinstance(batch_shape[0], tuple):
batch_shape = batch_shape[0]
elif isinstance(batch_shape[0], list):
batch_shape = tuple(batch_shape[0])
tangent_shape = batch_shape + (cls.manifold_dim,)
xi = torch.randn(tangent_shape, **kwargs)
return cls.exp(sigma * xi)
@classmethod
def apply_op(cls, op, x, y=None):
""" Apply group operator """
inputs, out_shape = broadcast_inputs(x, y)
data = op.apply(cls.group_id, *inputs)
return data.view(out_shape + (-1,))
@classmethod
def exp(cls, x):
""" exponential map: x -> X """
return cls(cls.apply_op(Exp, x))
def log(self):
""" logarithm map """
return self.apply_op(Log, self.data)
def inv(self):
""" group inverse """
return self.__class__(self.apply_op(Inv, self.data))
def mul(self, other):
""" group multiplication """
return self.__class__(self.apply_op(Mul, self.data, other.data))
def retr(self, a):
""" retraction: Exp(a) * X """
dX = self.__class__.apply_op(Exp, a)
return self.__class__(self.apply_op(Mul, dX, self.data))
def adj(self, a):
""" adjoint operator: b = A(X) * a """
return self.apply_op(Adj, self.data, a)
def adjT(self, a):
""" transposed adjoint operator: b = a * A(X) """
return self.apply_op(AdjT, self.data, a)
def Jinv(self, a):
return self.apply_op(Jinv, self.data, a)
    def act(self, p):
        """ action on a point cloud """
        # action on point
        if p.shape[-1] == 3:
            return self.apply_op(Act3, self.data, p)
        # action on homogeneous point
        elif p.shape[-1] == 4:
            return self.apply_op(Act4, self.data, p)
        else:
            raise ValueError("points must have 3 or 4 coordinates in the last dimension, got {}".format(p.shape[-1]))
def matrix(self):
""" convert element to 4x4 matrix """
input_shape = self.data.shape
mat = ToMatrix.apply(self.group_id, self.data.reshape(-1, self.embedded_dim))
return mat.view(input_shape[:-1] + (4,4))
def detach(self):
return self.__class__(self.data.detach())
def view(self, dims):
data_reshaped = self.data.view(dims + (self.embedded_dim,))
return self.__class__(data_reshaped)
def __mul__(self, other):
# group multiplication
if isinstance(other, LieGroup):
return self.mul(other)
# action on point
elif isinstance(other, torch.Tensor):
return self.act(other)
def __getitem__(self, index):
return self.__class__(self.data[index])
def __setitem__(self, index, item):
self.data[index] = item.data
def to(self, *args, **kwargs):
return self.__class__(self.data.to(*args, **kwargs))
def cpu(self):
return self.__class__(self.data.cpu())
def cuda(self):
return self.__class__(self.data.cuda())
    def float(self):
        return self.__class__(self.data.float())
    def double(self):
        return self.__class__(self.data.double())
def unbind(self, dim=0):
return [self.__class__(x) for x in self.data.unbind(dim=dim)]
class SO3(LieGroup):
group_name = 'SO3'
group_id = 1
manifold_dim = 3
embedded_dim = 4
    # unit quaternion (x, y, z, w)
id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])
def __init__(self, data):
if isinstance(data, SE3):
data = data.data[..., 3:7]
super(SO3, self).__init__(data)
class RxSO3(LieGroup):
group_name = 'RxSO3'
group_id = 2
manifold_dim = 4
embedded_dim = 5
    # unit quaternion, scale
id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])
def __init__(self, data):
if isinstance(data, Sim3):
data = data.data[..., 3:8]
super(RxSO3, self).__init__(data)
class SE3(LieGroup):
group_name = 'SE3'
group_id = 3
manifold_dim = 6
embedded_dim = 7
# translation, unit quaternion
id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
def __init__(self, data):
if isinstance(data, SO3):
translation = torch.zeros_like(data.data[...,:3])
data = torch.cat([translation, data.data], -1)
super(SE3, self).__init__(data)
def scale(self, s):
t, q = self.data.split([3,4], -1)
t = t * s.unsqueeze(-1)
return SE3(torch.cat([t, q], dim=-1))
def translation(self):
return ExtractTranslation.apply(self.data)
class Sim3(LieGroup):
group_name = 'Sim3'
group_id = 4
manifold_dim = 7
embedded_dim = 8
# translation, unit quaternion, scale
id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
    def __init__(self, data):
        if isinstance(data, SO3):
            scale = torch.ones_like(data.data[...,:1])
            translation = torch.zeros_like(data.data[...,:3])
            data = torch.cat([translation, data.data, scale], -1)
elif isinstance(data, SE3):
scale = torch.ones_like(data.data[...,:1])
data = torch.cat([data.data, scale], -1)
elif isinstance(data, Sim3):
data = data.data
super(Sim3, self).__init__(data)
def cat(group_list, dim):
""" Concatenate groups along dimension """
data = torch.cat([X.data for X in group_list], dim=dim)
return group_list[0].__class__(data)
def stack(group_list, dim):
""" Concatenate groups along dimension """
data = torch.stack([X.data for X in group_list], dim=dim)
return group_list[0].__class__(data)
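# A minimal smoke test, assuming a successfully compiled lietorch_backends
# extension with CUDA support; shapes and sigma are illustrative only.
if __name__ == "__main__":
    T = SE3.Random(2, 5, sigma=0.1, dtype=torch.double, device="cuda")
    p = torch.randn(2, 5, 3, dtype=torch.double, device="cuda")
    q = T * p                   # action on points (Act3)
    M = T.matrix()              # (2, 5, 4, 4) homogeneous matrices
    err = (T * T.inv()).log()   # ~0 when inverse and log are consistent
    print(q.shape, M.shape, err.abs().max().item())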
#ifndef COMMON_H
#define COMMON_H
#define EIGEN_DEFAULT_DENSE_INDEX_TYPE int
#define EIGEN_RUNTIME_NO_MALLOC
#define EPS 1e-6
#define PI 3.14159265358979323846
#endif
#ifndef DISPATCH_H
#define DISPATCH_H
#include <torch/extension.h>
#include "so3.h"
#include "rxso3.h"
#include "se3.h"
#include "sim3.h"
#define PRIVATE_CASE_TYPE(group_index, enum_type, type, ...) \
case enum_type: { \
using scalar_t = type; \
switch (group_index) { \
case 1: { \
using group_t = SO3<type>; \
return __VA_ARGS__(); \
} \
case 2: { \
using group_t = RxSO3<type>; \
return __VA_ARGS__(); \
} \
case 3: { \
using group_t = SE3<type>; \
return __VA_ARGS__(); \
} \
case 4: { \
using group_t = Sim3<type>; \
return __VA_ARGS__(); \
} \
} \
  }
#define DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_INDEX, TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__) \
PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__) \
} \
}()
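// A usage sketch (illustrative): the macro binds group_t and scalar_t before
// invoking the trailing lambda, mirroring AT_DISPATCH conventions; the
// templated exp_forward_kernel below is hypothetical, not declared here.
//
//   DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "exp_forward", ([&] {
//     exp_forward_kernel<group_t, scalar_t>(X, Y);
//   }));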
#endif
#ifndef LIETORCH_CPU_H_
#define LIETORCH_CPU_H_
#include <vector>
#include <torch/extension.h>
// unary operations
torch::Tensor exp_forward_cpu(int, torch::Tensor);
std::vector<torch::Tensor> exp_backward_cpu(int, torch::Tensor, torch::Tensor);
torch::Tensor log_forward_cpu(int, torch::Tensor);
std::vector<torch::Tensor> log_backward_cpu(int, torch::Tensor, torch::Tensor);
torch::Tensor inv_forward_cpu(int, torch::Tensor);
std::vector<torch::Tensor> inv_backward_cpu(int, torch::Tensor, torch::Tensor);
// binary operations
torch::Tensor mul_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> mul_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
torch::Tensor adj_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> adj_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
torch::Tensor adjT_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> adjT_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
torch::Tensor act_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
torch::Tensor act4_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act4_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
// utility operations
torch::Tensor as_matrix_forward_cpu(int, torch::Tensor);
torch::Tensor jleft_forward_cpu(int, torch::Tensor, torch::Tensor);
#endif