Commit bf5f3526 authored by Nicolas Gorlo

updated deprecated methods

parent 0fa9ce8f
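
The change is mechanical throughout: the deprecated at::Tensor::type() accessor is replaced by scalar_type() in the dispatch macros, and x.type().is_cuda() by x.is_cuda() in the input checks. A minimal before/after sketch of the pattern (the function and tensor names here are illustrative, not taken from this repository):

#include <torch/extension.h>

// Illustrative only: old deprecated calls shown next to their replacements.
torch::Tensor demo(torch::Tensor t) {
    // Deprecated: TORCH_CHECK(t.type().is_cuda(), "t must be a CUDA tensor");
    TORCH_CHECK(t.is_cuda(), "t must be a CUDA tensor");

    torch::Tensor out = torch::zeros_like(t);
    // Deprecated: AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.type(), ...)
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "demo", ([&] {
        // The macro binds scalar_t to float, double, or at::Half to match t.
        scalar_t* data = out.data_ptr<scalar_t>();
        (void)data;  // a real body would launch a kernel on this pointer
    }));
    return out;
}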
@@ -142,7 +142,7 @@ std::vector<torch::Tensor> corr_index_cuda_forward(
     torch::Tensor corr = torch::zeros(
       {batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_forward_kernel", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] {
       corr_index_forward_kernel<scalar_t><<<blocks, threads>>>(
         volume.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
         coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
@@ -173,7 +173,7 @@ std::vector<torch::Tensor> corr_index_cuda_backward(
     const dim3 threads(BLOCK, BLOCK);
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_backward_kernel", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] {
      corr_index_backward_kernel<scalar_t><<<blocks, threads>>>(
        coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
        corr_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
...
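
For reference, the dispatch-plus-launch pattern these two hunks touch looks like this when written out in full. This is a self-contained sketch with an assumed toy 2-D kernel, not the file's actual sampler code:

#include <torch/extension.h>

// Assumed toy kernel: scales a 2-D tensor in place via a bounds-checked accessor.
template <typename scalar_t>
__global__ void scale_kernel(
    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> x,
    float s)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i < x.size(0) && j < x.size(1))
        x[i][j] = static_cast<scalar_t>(static_cast<float>(x[i][j]) * s);
}

void scale_inplace(torch::Tensor x, float s) {
    const dim3 threads(16, 16);
    const dim3 blocks((x.size(0) + 15) / 16, (x.size(1) + 15) / 16);
    // scalar_type() (not the deprecated type()) selects the instantiation.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "scale_kernel", ([&] {
        scale_kernel<scalar_t><<<blocks, threads>>>(
            x.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(), s);
    }));
}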
@@ -3,7 +3,7 @@
 // C++ interface
-#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...
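
These macros are the usual validation idiom at the top of extension entry points; the updated CHECK_CUDA queries the tensor directly instead of going through the deprecated type object. A hypothetical caller (the name my_op_forward is illustrative):

torch::Tensor my_op_forward(torch::Tensor volume, torch::Tensor coords) {
    CHECK_INPUT(volume);   // CHECK_CUDA + CHECK_CONTIGUOUS
    CHECK_INPUT(coords);
    // ... safe to take raw pointers / accessors from here on ...
    return volume;
}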
@@ -357,7 +357,7 @@ torch::Tensor exp_forward_cpu(int group_id, torch::Tensor a) {
     int batch_size = a.size(0);
     torch::Tensor X;
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_forward_kernel", ([&] {
         X = torch::zeros({batch_size, group_t::N}, a.options());
         exp_forward_kernel<group_t, scalar_t>(
             a.data_ptr<scalar_t>(),
@@ -372,7 +372,7 @@ std::vector<torch::Tensor> exp_backward_cpu(int group_id, torch::Tensor grad, to
     int batch_size = a.size(0);
     torch::Tensor da = torch::zeros(a.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_backward_kernel", ([&] {
         exp_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
@@ -387,7 +387,7 @@ torch::Tensor log_forward_cpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor a;
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_forward_kernel", ([&] {
         a = torch::zeros({batch_size, group_t::K}, X.options());
         log_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
@@ -402,7 +402,7 @@ std::vector<torch::Tensor> log_backward_cpu(int group_id, torch::Tensor grad, to
     int batch_size = X.size(0);
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_backward_kernel", ([&] {
         log_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -417,7 +417,7 @@ torch::Tensor inv_forward_cpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor Y = torch::zeros_like(X);
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_forward_kernel", ([&] {
         inv_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             Y.data_ptr<scalar_t>(),
@@ -431,7 +431,7 @@ std::vector<torch::Tensor> inv_backward_cpu(int group_id, torch::Tensor grad, to
     int batch_size = X.size(0);
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_backward_kernel", ([&] {
         inv_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -447,7 +447,7 @@ torch::Tensor mul_forward_cpu(int group_id, torch::Tensor X, torch::Tensor Y) {
     int batch_size = X.size(0);
     torch::Tensor Z = torch::zeros_like(X);
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_forward_kernel", ([&] {
         mul_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             Y.data_ptr<scalar_t>(),
@@ -463,7 +463,7 @@ std::vector<torch::Tensor> mul_backward_cpu(int group_id, torch::Tensor grad, to
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_backward_kernel", ([&] {
         mul_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -480,7 +480,7 @@ torch::Tensor adj_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
     int batch_size = X.size(0);
     torch::Tensor b = torch::zeros(a.sizes(), a.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_forward_kernel", ([&] {
         adj_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
@@ -496,7 +496,7 @@ std::vector<torch::Tensor> adj_backward_cpu(int group_id, torch::Tensor grad, to
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor da = torch::zeros(a.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_backward_kernel", ([&] {
         adj_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -514,7 +514,7 @@ torch::Tensor adjT_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
     int batch_size = X.size(0);
     torch::Tensor b = torch::zeros(a.sizes(), a.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_forward_kernel", ([&] {
         adjT_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
@@ -530,7 +530,7 @@ std::vector<torch::Tensor> adjT_backward_cpu(int group_id, torch::Tensor grad, t
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor da = torch::zeros(a.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_backward_kernel", ([&] {
         adjT_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -548,7 +548,7 @@ torch::Tensor act_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
     int batch_size = X.size(0);
     torch::Tensor q = torch::zeros(p.sizes(), p.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_forward_kernel", ([&] {
         act_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             p.data_ptr<scalar_t>(),
@@ -564,7 +564,7 @@ std::vector<torch::Tensor> act_backward_cpu(int group_id, torch::Tensor grad, to
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_backward_kernel", ([&] {
         act_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -582,7 +582,7 @@ torch::Tensor act4_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
     int batch_size = X.size(0);
     torch::Tensor q = torch::zeros(p.sizes(), p.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_forward_kernel", ([&] {
         act4_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             p.data_ptr<scalar_t>(),
@@ -598,7 +598,7 @@ std::vector<torch::Tensor> act4_backward_cpu(int group_id, torch::Tensor grad, t
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_backward_kernel", ([&] {
         act4_backward_kernel<group_t, scalar_t>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -616,7 +616,7 @@ torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "as_matrix_forward_kernel", ([&] {
         as_matrix_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             T4x4.data_ptr<scalar_t>(),
@@ -631,7 +631,7 @@ torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor P;
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "orthogonal_projector_kernel", ([&] {
         P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
         orthogonal_projector_kernel<group_t, scalar_t>(X.data_ptr<scalar_t>(), P.data_ptr<scalar_t>(), batch_size);
     }));
@@ -645,7 +645,7 @@ torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a)
     int batch_size = X.size(0);
     torch::Tensor b = torch::zeros(a.sizes(), a.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "jleft_forward_kernel", ([&] {
         jleft_forward_kernel<group_t, scalar_t>(
             X.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
...
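
DISPATCH_GROUP_AND_FLOATING_TYPES is a project-level macro, not part of ATen; only its call sites appear in this diff. Its general shape is a group-index switch combined with ATen's floating-type dispatch, so the lambda body sees both group_t and scalar_t as type aliases. A simplified sketch under that assumption (the group ids, stub types, and dimensions here are illustrative, not the repository's verbatim code):

#include <torch/extension.h>

// Stubs standing in for the project's group classes (assumed shapes):
// N = elements in the group's storage, K = tangent-space dimension.
template <typename T> struct SO3 { static constexpr int N = 4; static constexpr int K = 3; };
template <typename T> struct SE3 { static constexpr int N = 7; static constexpr int K = 6; };

// Simplified two-level dispatcher in the same spirit as the macro above.
// The parenthesized lambda passed at the call site expands inside each case
// block, where both group_t and scalar_t are in scope.
#define SKETCH_DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_ID, STYPE, NAME, ...)   \
    AT_DISPATCH_FLOATING_TYPES(STYPE, NAME, ([&] {                             \
        switch (GROUP_ID) {                                                    \
            case 1: { using group_t = SO3<scalar_t>; __VA_ARGS__(); break; }   \
            case 3: { using group_t = SE3<scalar_t>; __VA_ARGS__(); break; }   \
            default: TORCH_CHECK(false, NAME, ": unknown group_id");           \
        }                                                                      \
    }))

A call site then looks exactly like the ones in this diff: SKETCH_DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_forward_kernel", ([&] { ... }));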
@@ -299,7 +299,7 @@ torch::Tensor exp_forward_gpu(int group_id, torch::Tensor a) {
     int batch_size = a.size(0);
     torch::Tensor X;
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_forward_kernel", ([&] {
         X = torch::zeros({batch_size, group_t::N}, a.options());
         exp_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             a.data_ptr<scalar_t>(),
@@ -314,7 +314,7 @@ std::vector<torch::Tensor> exp_backward_gpu(int group_id, torch::Tensor grad, to
     int batch_size = a.size(0);
     torch::Tensor da = torch::zeros(a.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_backward_kernel", ([&] {
         exp_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
@@ -329,7 +329,7 @@ torch::Tensor log_forward_gpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor a;
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_forward_kernel", ([&] {
         a = torch::zeros({batch_size, group_t::K}, X.options());
         log_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
@@ -344,7 +344,7 @@ std::vector<torch::Tensor> log_backward_gpu(int group_id, torch::Tensor grad, to
     int batch_size = X.size(0);
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_backward_kernel", ([&] {
         log_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -359,7 +359,7 @@ torch::Tensor inv_forward_gpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor Y = torch::zeros_like(X);
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_forward_kernel", ([&] {
         inv_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             Y.data_ptr<scalar_t>(),
@@ -373,7 +373,7 @@ std::vector<torch::Tensor> inv_backward_gpu(int group_id, torch::Tensor grad, to
     int batch_size = X.size(0);
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_backward_kernel", ([&] {
         inv_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -389,7 +389,7 @@ torch::Tensor mul_forward_gpu(int group_id, torch::Tensor X, torch::Tensor Y) {
     int batch_size = X.size(0);
     torch::Tensor Z = torch::zeros_like(X);
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_forward_kernel", ([&] {
         mul_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             Y.data_ptr<scalar_t>(),
@@ -405,7 +405,7 @@ std::vector<torch::Tensor> mul_backward_gpu(int group_id, torch::Tensor grad, to
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_backward_kernel", ([&] {
         mul_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -422,7 +422,7 @@ torch::Tensor adj_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
     int batch_size = X.size(0);
     torch::Tensor b = torch::zeros(a.sizes(), a.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_forward_kernel", ([&] {
         adj_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
@@ -438,7 +438,7 @@ std::vector<torch::Tensor> adj_backward_gpu(int group_id, torch::Tensor grad, to
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor da = torch::zeros(a.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_backward_kernel", ([&] {
         adj_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -456,7 +456,7 @@ torch::Tensor adjT_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
     int batch_size = X.size(0);
     torch::Tensor b = torch::zeros(a.sizes(), a.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_forward_kernel", ([&] {
         adjT_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
@@ -472,7 +472,7 @@ std::vector<torch::Tensor> adjT_backward_gpu(int group_id, torch::Tensor grad, t
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor da = torch::zeros(a.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_backward_kernel", ([&] {
         adjT_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -491,7 +491,7 @@ torch::Tensor act_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
     int batch_size = X.size(0);
     torch::Tensor q = torch::zeros(p.sizes(), p.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_forward_kernel", ([&] {
         act_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             p.data_ptr<scalar_t>(),
@@ -507,7 +507,7 @@ std::vector<torch::Tensor> act_backward_gpu(int group_id, torch::Tensor grad, to
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_backward_kernel", ([&] {
         act_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -524,7 +524,7 @@ torch::Tensor act4_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
     int batch_size = X.size(0);
     torch::Tensor q = torch::zeros(p.sizes(), p.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_forward_kernel", ([&] {
         act4_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             p.data_ptr<scalar_t>(),
@@ -540,7 +540,7 @@ std::vector<torch::Tensor> act4_backward_gpu(int group_id, torch::Tensor grad, t
     torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
     torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_backward_kernel", ([&] {
         act4_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             grad.data_ptr<scalar_t>(),
             X.data_ptr<scalar_t>(),
@@ -558,7 +558,7 @@ torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "as_matrix_forward_kernel", ([&] {
         as_matrix_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             T4x4.data_ptr<scalar_t>(),
@@ -573,7 +573,7 @@ torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
     int batch_size = X.size(0);
     torch::Tensor P;
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "orthogonal_projector_kernel", ([&] {
         P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
         orthogonal_projector_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
@@ -589,7 +589,7 @@ torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a)
     int batch_size = X.size(0);
     torch::Tensor b = torch::zeros(a.sizes(), a.options());
-    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
+    DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "jleft_forward_kernel", ([&] {
         jleft_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
             X.data_ptr<scalar_t>(),
             a.data_ptr<scalar_t>(),
...
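
The GPU variants differ from their CPU counterparts only in the <<<NUM_BLOCKS(batch_size), NUM_THREADS>>> launch configuration. Those helpers are defined elsewhere in the file and are not shown in this diff; a conventional definition consistent with these call sites would be (an assumption, not the repository's verbatim code):

// Assumed one-thread-per-batch-element launch helpers (not shown in the diff).
#define NUM_THREADS 256
#define NUM_BLOCKS(batch_size) (((batch_size) + NUM_THREADS - 1) / NUM_THREADS)

The ceiling division rounds the grid up so every batch element gets a thread; kernels launched this way typically guard with an index check such as if (i < batch_size) before touching memory.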