Commit bf5f3526 authored by Nicolas Gorlo's avatar Nicolas Gorlo
Browse files

updated deprecated methods

parent 0fa9ce8f
......@@ -142,7 +142,7 @@ std::vector<torch::Tensor> corr_index_cuda_forward(
torch::Tensor corr = torch::zeros(
{batch_size, 2*radius+1, 2*radius+1, ht, wd}, opts);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_forward_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_forward_kernel", ([&] {
corr_index_forward_kernel<scalar_t><<<blocks, threads>>>(
volume.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
......@@ -173,7 +173,7 @@ std::vector<torch::Tensor> corr_index_cuda_backward(
const dim3 threads(BLOCK, BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.type(), "sampler_backward_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(volume.scalar_type(), "sampler_backward_kernel", ([&] {
corr_index_backward_kernel<scalar_t><<<blocks, threads>>>(
coords.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
corr_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
......
......@@ -3,7 +3,7 @@
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
......
......@@ -357,7 +357,7 @@ torch::Tensor exp_forward_cpu(int group_id, torch::Tensor a) {
int batch_size = a.size(0);
torch::Tensor X;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_forward_kernel", ([&] {
X = torch::zeros({batch_size, group_t::N}, a.options());
exp_forward_kernel<group_t, scalar_t>(
a.data_ptr<scalar_t>(),
......@@ -372,7 +372,7 @@ std::vector<torch::Tensor> exp_backward_cpu(int group_id, torch::Tensor grad, to
int batch_size = a.size(0);
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_backward_kernel", ([&] {
exp_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......@@ -387,7 +387,7 @@ torch::Tensor log_forward_cpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor a;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_forward_kernel", ([&] {
a = torch::zeros({batch_size, group_t::K}, X.options());
log_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
......@@ -402,7 +402,7 @@ std::vector<torch::Tensor> log_backward_cpu(int group_id, torch::Tensor grad, to
int batch_size = X.size(0);
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_backward_kernel", ([&] {
log_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -417,7 +417,7 @@ torch::Tensor inv_forward_cpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor Y = torch::zeros_like(X);
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_forward_kernel", ([&] {
inv_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
Y.data_ptr<scalar_t>(),
......@@ -431,7 +431,7 @@ std::vector<torch::Tensor> inv_backward_cpu(int group_id, torch::Tensor grad, to
int batch_size = X.size(0);
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_backward_kernel", ([&] {
inv_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -447,7 +447,7 @@ torch::Tensor mul_forward_cpu(int group_id, torch::Tensor X, torch::Tensor Y) {
int batch_size = X.size(0);
torch::Tensor Z = torch::zeros_like(X);
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_forward_kernel", ([&] {
mul_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
Y.data_ptr<scalar_t>(),
......@@ -463,7 +463,7 @@ std::vector<torch::Tensor> mul_backward_cpu(int group_id, torch::Tensor grad, to
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_backward_kernel", ([&] {
mul_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -480,7 +480,7 @@ torch::Tensor adj_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_forward_kernel", ([&] {
adj_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......@@ -496,7 +496,7 @@ std::vector<torch::Tensor> adj_backward_cpu(int group_id, torch::Tensor grad, to
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_backward_kernel", ([&] {
adj_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -514,7 +514,7 @@ torch::Tensor adjT_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_forward_kernel", ([&] {
adjT_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......@@ -530,7 +530,7 @@ std::vector<torch::Tensor> adjT_backward_cpu(int group_id, torch::Tensor grad, t
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_backward_kernel", ([&] {
adjT_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -548,7 +548,7 @@ torch::Tensor act_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
int batch_size = X.size(0);
torch::Tensor q = torch::zeros(p.sizes(), p.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_forward_kernel", ([&] {
act_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
p.data_ptr<scalar_t>(),
......@@ -564,7 +564,7 @@ std::vector<torch::Tensor> act_backward_cpu(int group_id, torch::Tensor grad, to
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_backward_kernel", ([&] {
act_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -582,7 +582,7 @@ torch::Tensor act4_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
int batch_size = X.size(0);
torch::Tensor q = torch::zeros(p.sizes(), p.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_forward_kernel", ([&] {
act4_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
p.data_ptr<scalar_t>(),
......@@ -598,7 +598,7 @@ std::vector<torch::Tensor> act4_backward_cpu(int group_id, torch::Tensor grad, t
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_backward_kernel", ([&] {
act4_backward_kernel<group_t, scalar_t>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -616,7 +616,7 @@ torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "as_matrix_forward_kernel", ([&] {
as_matrix_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
T4x4.data_ptr<scalar_t>(),
......@@ -631,7 +631,7 @@ torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor P;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "orthogonal_projector_kernel", ([&] {
P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
orthogonal_projector_kernel<group_t, scalar_t>(X.data_ptr<scalar_t>(), P.data_ptr<scalar_t>(), batch_size);
}));
......@@ -645,7 +645,7 @@ torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a)
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "jleft_forward_kernel", ([&] {
jleft_forward_kernel<group_t, scalar_t>(
X.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......
......@@ -299,7 +299,7 @@ torch::Tensor exp_forward_gpu(int group_id, torch::Tensor a) {
int batch_size = a.size(0);
torch::Tensor X;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_forward_kernel", ([&] {
X = torch::zeros({batch_size, group_t::N}, a.options());
exp_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
a.data_ptr<scalar_t>(),
......@@ -314,7 +314,7 @@ std::vector<torch::Tensor> exp_backward_gpu(int group_id, torch::Tensor grad, to
int batch_size = a.size(0);
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.scalar_type(), "exp_backward_kernel", ([&] {
exp_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......@@ -329,7 +329,7 @@ torch::Tensor log_forward_gpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor a;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_forward_kernel", ([&] {
a = torch::zeros({batch_size, group_t::K}, X.options());
log_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
......@@ -344,7 +344,7 @@ std::vector<torch::Tensor> log_backward_gpu(int group_id, torch::Tensor grad, to
int batch_size = X.size(0);
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "log_backward_kernel", ([&] {
log_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -359,7 +359,7 @@ torch::Tensor inv_forward_gpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor Y = torch::zeros_like(X);
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_forward_kernel", ([&] {
inv_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
Y.data_ptr<scalar_t>(),
......@@ -373,7 +373,7 @@ std::vector<torch::Tensor> inv_backward_gpu(int group_id, torch::Tensor grad, to
int batch_size = X.size(0);
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "inv_backward_kernel", ([&] {
inv_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -389,7 +389,7 @@ torch::Tensor mul_forward_gpu(int group_id, torch::Tensor X, torch::Tensor Y) {
int batch_size = X.size(0);
torch::Tensor Z = torch::zeros_like(X);
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_forward_kernel", ([&] {
mul_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
Y.data_ptr<scalar_t>(),
......@@ -405,7 +405,7 @@ std::vector<torch::Tensor> mul_backward_gpu(int group_id, torch::Tensor grad, to
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "mul_backward_kernel", ([&] {
mul_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -422,7 +422,7 @@ torch::Tensor adj_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_forward_kernel", ([&] {
adj_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......@@ -438,7 +438,7 @@ std::vector<torch::Tensor> adj_backward_gpu(int group_id, torch::Tensor grad, to
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adj_backward_kernel", ([&] {
adj_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -456,7 +456,7 @@ torch::Tensor adjT_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_forward_kernel", ([&] {
adjT_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......@@ -472,7 +472,7 @@ std::vector<torch::Tensor> adjT_backward_gpu(int group_id, torch::Tensor grad, t
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "adjT_backward_kernel", ([&] {
adjT_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -491,7 +491,7 @@ torch::Tensor act_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
int batch_size = X.size(0);
torch::Tensor q = torch::zeros(p.sizes(), p.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_forward_kernel", ([&] {
act_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
p.data_ptr<scalar_t>(),
......@@ -507,7 +507,7 @@ std::vector<torch::Tensor> act_backward_gpu(int group_id, torch::Tensor grad, to
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act_backward_kernel", ([&] {
act_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -524,7 +524,7 @@ torch::Tensor act4_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
int batch_size = X.size(0);
torch::Tensor q = torch::zeros(p.sizes(), p.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_forward_kernel", ([&] {
act4_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
p.data_ptr<scalar_t>(),
......@@ -540,7 +540,7 @@ std::vector<torch::Tensor> act4_backward_gpu(int group_id, torch::Tensor grad, t
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "act4_backward_kernel", ([&] {
act4_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
grad.data_ptr<scalar_t>(),
X.data_ptr<scalar_t>(),
......@@ -558,7 +558,7 @@ torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "as_matrix_forward_kernel", ([&] {
as_matrix_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
T4x4.data_ptr<scalar_t>(),
......@@ -573,7 +573,7 @@ torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
int batch_size = X.size(0);
torch::Tensor P;
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "orthogonal_projector_kernel", ([&] {
P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
orthogonal_projector_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
......@@ -589,7 +589,7 @@ torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a)
int batch_size = X.size(0);
torch::Tensor b = torch::zeros(a.sizes(), a.options());
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.scalar_type(), "jleft_forward_kernel", ([&] {
jleft_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
X.data_ptr<scalar_t>(),
a.data_ptr<scalar_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment