Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
lietorch
Commits
bf5f3526
Commit
bf5f3526
authored
Apr 30, 2025
by
Nicolas Gorlo
Browse files
updated deprecated methods
parent
0fa9ce8f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
41 deletions
+41
-41
lietorch/extras/corr_index_kernel.cu
lietorch/extras/corr_index_kernel.cu
+2
-2
lietorch/extras/extras.cpp
lietorch/extras/extras.cpp
+1
-1
lietorch/src/lietorch_cpu.cpp
lietorch/src/lietorch_cpu.cpp
+19
-19
lietorch/src/lietorch_gpu.cu
lietorch/src/lietorch_gpu.cu
+19
-19
No files found.
lietorch/extras/corr_index_kernel.cu
View file @
bf5f3526
...
@@ -142,7 +142,7 @@ std::vector<torch::Tensor> corr_index_cuda_forward(
...
@@ -142,7 +142,7 @@ std::vector<torch::Tensor> corr_index_cuda_forward(
torch
::
Tensor
corr
=
torch
::
zeros
(
torch
::
Tensor
corr
=
torch
::
zeros
(
{
batch_size
,
2
*
radius
+
1
,
2
*
radius
+
1
,
ht
,
wd
},
opts
);
{
batch_size
,
2
*
radius
+
1
,
2
*
radius
+
1
,
ht
,
wd
},
opts
);
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
volume
.
type
(),
"sampler_forward_kernel"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
volume
.
scalar_
type
(),
"sampler_forward_kernel"
,
([
&
]
{
corr_index_forward_kernel
<
scalar_t
><<<
blocks
,
threads
>>>
(
corr_index_forward_kernel
<
scalar_t
><<<
blocks
,
threads
>>>
(
volume
.
packed_accessor32
<
scalar_t
,
5
,
torch
::
RestrictPtrTraits
>
(),
volume
.
packed_accessor32
<
scalar_t
,
5
,
torch
::
RestrictPtrTraits
>
(),
coords
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
coords
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
...
@@ -173,7 +173,7 @@ std::vector<torch::Tensor> corr_index_cuda_backward(
...
@@ -173,7 +173,7 @@ std::vector<torch::Tensor> corr_index_cuda_backward(
const
dim3
threads
(
BLOCK
,
BLOCK
);
const
dim3
threads
(
BLOCK
,
BLOCK
);
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
volume
.
type
(),
"sampler_backward_kernel"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
volume
.
scalar_
type
(),
"sampler_backward_kernel"
,
([
&
]
{
corr_index_backward_kernel
<
scalar_t
><<<
blocks
,
threads
>>>
(
corr_index_backward_kernel
<
scalar_t
><<<
blocks
,
threads
>>>
(
coords
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
coords
.
packed_accessor32
<
float
,
4
,
torch
::
RestrictPtrTraits
>
(),
corr_grad
.
packed_accessor32
<
scalar_t
,
5
,
torch
::
RestrictPtrTraits
>
(),
corr_grad
.
packed_accessor32
<
scalar_t
,
5
,
torch
::
RestrictPtrTraits
>
(),
...
...
lietorch/extras/extras.cpp
View file @
bf5f3526
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
// C++ interface
// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.
type().
is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...
...
lietorch/src/lietorch_cpu.cpp
View file @
bf5f3526
...
@@ -357,7 +357,7 @@ torch::Tensor exp_forward_cpu(int group_id, torch::Tensor a) {
...
@@ -357,7 +357,7 @@ torch::Tensor exp_forward_cpu(int group_id, torch::Tensor a) {
int
batch_size
=
a
.
size
(
0
);
int
batch_size
=
a
.
size
(
0
);
torch
::
Tensor
X
;
torch
::
Tensor
X
;
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
type
(),
"exp_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
scalar_
type
(),
"exp_forward_kernel"
,
([
&
]
{
X
=
torch
::
zeros
({
batch_size
,
group_t
::
N
},
a
.
options
());
X
=
torch
::
zeros
({
batch_size
,
group_t
::
N
},
a
.
options
());
exp_forward_kernel
<
group_t
,
scalar_t
>
(
exp_forward_kernel
<
group_t
,
scalar_t
>
(
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -372,7 +372,7 @@ std::vector<torch::Tensor> exp_backward_cpu(int group_id, torch::Tensor grad, to
...
@@ -372,7 +372,7 @@ std::vector<torch::Tensor> exp_backward_cpu(int group_id, torch::Tensor grad, to
int
batch_size
=
a
.
size
(
0
);
int
batch_size
=
a
.
size
(
0
);
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
type
(),
"exp_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
scalar_
type
(),
"exp_backward_kernel"
,
([
&
]
{
exp_backward_kernel
<
group_t
,
scalar_t
>
(
exp_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -387,7 +387,7 @@ torch::Tensor log_forward_cpu(int group_id, torch::Tensor X) {
...
@@ -387,7 +387,7 @@ torch::Tensor log_forward_cpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
a
;
torch
::
Tensor
a
;
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"log_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"log_forward_kernel"
,
([
&
]
{
a
=
torch
::
zeros
({
batch_size
,
group_t
::
K
},
X
.
options
());
a
=
torch
::
zeros
({
batch_size
,
group_t
::
K
},
X
.
options
());
log_forward_kernel
<
group_t
,
scalar_t
>
(
log_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -402,7 +402,7 @@ std::vector<torch::Tensor> log_backward_cpu(int group_id, torch::Tensor grad, to
...
@@ -402,7 +402,7 @@ std::vector<torch::Tensor> log_backward_cpu(int group_id, torch::Tensor grad, to
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"log_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"log_backward_kernel"
,
([
&
]
{
log_backward_kernel
<
group_t
,
scalar_t
>
(
log_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -417,7 +417,7 @@ torch::Tensor inv_forward_cpu(int group_id, torch::Tensor X) {
...
@@ -417,7 +417,7 @@ torch::Tensor inv_forward_cpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
Y
=
torch
::
zeros_like
(
X
);
torch
::
Tensor
Y
=
torch
::
zeros_like
(
X
);
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"inv_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"inv_forward_kernel"
,
([
&
]
{
inv_forward_kernel
<
group_t
,
scalar_t
>
(
inv_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
...
@@ -431,7 +431,7 @@ std::vector<torch::Tensor> inv_backward_cpu(int group_id, torch::Tensor grad, to
...
@@ -431,7 +431,7 @@ std::vector<torch::Tensor> inv_backward_cpu(int group_id, torch::Tensor grad, to
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"inv_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"inv_backward_kernel"
,
([
&
]
{
inv_backward_kernel
<
group_t
,
scalar_t
>
(
inv_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -447,7 +447,7 @@ torch::Tensor mul_forward_cpu(int group_id, torch::Tensor X, torch::Tensor Y) {
...
@@ -447,7 +447,7 @@ torch::Tensor mul_forward_cpu(int group_id, torch::Tensor X, torch::Tensor Y) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
Z
=
torch
::
zeros_like
(
X
);
torch
::
Tensor
Z
=
torch
::
zeros_like
(
X
);
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"mul_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"mul_forward_kernel"
,
([
&
]
{
mul_forward_kernel
<
group_t
,
scalar_t
>
(
mul_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
...
@@ -463,7 +463,7 @@ std::vector<torch::Tensor> mul_backward_cpu(int group_id, torch::Tensor grad, to
...
@@ -463,7 +463,7 @@ std::vector<torch::Tensor> mul_backward_cpu(int group_id, torch::Tensor grad, to
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dY
=
torch
::
zeros
(
Y
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dY
=
torch
::
zeros
(
Y
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"mul_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"mul_backward_kernel"
,
([
&
]
{
mul_backward_kernel
<
group_t
,
scalar_t
>
(
mul_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -480,7 +480,7 @@ torch::Tensor adj_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
...
@@ -480,7 +480,7 @@ torch::Tensor adj_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adj_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adj_forward_kernel"
,
([
&
]
{
adj_forward_kernel
<
group_t
,
scalar_t
>
(
adj_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -496,7 +496,7 @@ std::vector<torch::Tensor> adj_backward_cpu(int group_id, torch::Tensor grad, to
...
@@ -496,7 +496,7 @@ std::vector<torch::Tensor> adj_backward_cpu(int group_id, torch::Tensor grad, to
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adj_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adj_backward_kernel"
,
([
&
]
{
adj_backward_kernel
<
group_t
,
scalar_t
>
(
adj_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -514,7 +514,7 @@ torch::Tensor adjT_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
...
@@ -514,7 +514,7 @@ torch::Tensor adjT_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adjT_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adjT_forward_kernel"
,
([
&
]
{
adjT_forward_kernel
<
group_t
,
scalar_t
>
(
adjT_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -530,7 +530,7 @@ std::vector<torch::Tensor> adjT_backward_cpu(int group_id, torch::Tensor grad, t
...
@@ -530,7 +530,7 @@ std::vector<torch::Tensor> adjT_backward_cpu(int group_id, torch::Tensor grad, t
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adjT_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adjT_backward_kernel"
,
([
&
]
{
adjT_backward_kernel
<
group_t
,
scalar_t
>
(
adjT_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -548,7 +548,7 @@ torch::Tensor act_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
...
@@ -548,7 +548,7 @@ torch::Tensor act_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act_forward_kernel"
,
([
&
]
{
act_forward_kernel
<
group_t
,
scalar_t
>
(
act_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
...
@@ -564,7 +564,7 @@ std::vector<torch::Tensor> act_backward_cpu(int group_id, torch::Tensor grad, to
...
@@ -564,7 +564,7 @@ std::vector<torch::Tensor> act_backward_cpu(int group_id, torch::Tensor grad, to
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act_backward_kernel"
,
([
&
]
{
act_backward_kernel
<
group_t
,
scalar_t
>
(
act_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -582,7 +582,7 @@ torch::Tensor act4_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
...
@@ -582,7 +582,7 @@ torch::Tensor act4_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act4_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act4_forward_kernel"
,
([
&
]
{
act4_forward_kernel
<
group_t
,
scalar_t
>
(
act4_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
...
@@ -598,7 +598,7 @@ std::vector<torch::Tensor> act4_backward_cpu(int group_id, torch::Tensor grad, t
...
@@ -598,7 +598,7 @@ std::vector<torch::Tensor> act4_backward_cpu(int group_id, torch::Tensor grad, t
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act4_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act4_backward_kernel"
,
([
&
]
{
act4_backward_kernel
<
group_t
,
scalar_t
>
(
act4_backward_kernel
<
group_t
,
scalar_t
>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -616,7 +616,7 @@ torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
...
@@ -616,7 +616,7 @@ torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
T4x4
=
torch
::
zeros
({
X
.
size
(
0
),
4
,
4
},
X
.
options
());
torch
::
Tensor
T4x4
=
torch
::
zeros
({
X
.
size
(
0
),
4
,
4
},
X
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"as_matrix_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"as_matrix_forward_kernel"
,
([
&
]
{
as_matrix_forward_kernel
<
group_t
,
scalar_t
>
(
as_matrix_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
T4x4
.
data_ptr
<
scalar_t
>
(),
T4x4
.
data_ptr
<
scalar_t
>
(),
...
@@ -631,7 +631,7 @@ torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
...
@@ -631,7 +631,7 @@ torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
P
;
torch
::
Tensor
P
;
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"orthogonal_projector_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"orthogonal_projector_kernel"
,
([
&
]
{
P
=
torch
::
zeros
({
X
.
size
(
0
),
group_t
::
N
,
group_t
::
N
},
X
.
options
());
P
=
torch
::
zeros
({
X
.
size
(
0
),
group_t
::
N
,
group_t
::
N
},
X
.
options
());
orthogonal_projector_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
P
.
data_ptr
<
scalar_t
>
(),
batch_size
);
orthogonal_projector_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
P
.
data_ptr
<
scalar_t
>
(),
batch_size
);
}));
}));
...
@@ -645,7 +645,7 @@ torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a)
...
@@ -645,7 +645,7 @@ torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a)
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"jleft_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"jleft_forward_kernel"
,
([
&
]
{
jleft_forward_kernel
<
group_t
,
scalar_t
>
(
jleft_forward_kernel
<
group_t
,
scalar_t
>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
...
lietorch/src/lietorch_gpu.cu
View file @
bf5f3526
...
@@ -299,7 +299,7 @@ torch::Tensor exp_forward_gpu(int group_id, torch::Tensor a) {
...
@@ -299,7 +299,7 @@ torch::Tensor exp_forward_gpu(int group_id, torch::Tensor a) {
int
batch_size
=
a
.
size
(
0
);
int
batch_size
=
a
.
size
(
0
);
torch
::
Tensor
X
;
torch
::
Tensor
X
;
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
type
(),
"exp_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
scalar_
type
(),
"exp_forward_kernel"
,
([
&
]
{
X
=
torch
::
zeros
({
batch_size
,
group_t
::
N
},
a
.
options
());
X
=
torch
::
zeros
({
batch_size
,
group_t
::
N
},
a
.
options
());
exp_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
exp_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -314,7 +314,7 @@ std::vector<torch::Tensor> exp_backward_gpu(int group_id, torch::Tensor grad, to
...
@@ -314,7 +314,7 @@ std::vector<torch::Tensor> exp_backward_gpu(int group_id, torch::Tensor grad, to
int
batch_size
=
a
.
size
(
0
);
int
batch_size
=
a
.
size
(
0
);
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
type
(),
"exp_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
a
.
scalar_
type
(),
"exp_backward_kernel"
,
([
&
]
{
exp_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
exp_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -329,7 +329,7 @@ torch::Tensor log_forward_gpu(int group_id, torch::Tensor X) {
...
@@ -329,7 +329,7 @@ torch::Tensor log_forward_gpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
a
;
torch
::
Tensor
a
;
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"log_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"log_forward_kernel"
,
([
&
]
{
a
=
torch
::
zeros
({
batch_size
,
group_t
::
K
},
X
.
options
());
a
=
torch
::
zeros
({
batch_size
,
group_t
::
K
},
X
.
options
());
log_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
log_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -344,7 +344,7 @@ std::vector<torch::Tensor> log_backward_gpu(int group_id, torch::Tensor grad, to
...
@@ -344,7 +344,7 @@ std::vector<torch::Tensor> log_backward_gpu(int group_id, torch::Tensor grad, to
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"log_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"log_backward_kernel"
,
([
&
]
{
log_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
log_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -359,7 +359,7 @@ torch::Tensor inv_forward_gpu(int group_id, torch::Tensor X) {
...
@@ -359,7 +359,7 @@ torch::Tensor inv_forward_gpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
Y
=
torch
::
zeros_like
(
X
);
torch
::
Tensor
Y
=
torch
::
zeros_like
(
X
);
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"inv_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"inv_forward_kernel"
,
([
&
]
{
inv_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
inv_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
...
@@ -373,7 +373,7 @@ std::vector<torch::Tensor> inv_backward_gpu(int group_id, torch::Tensor grad, to
...
@@ -373,7 +373,7 @@ std::vector<torch::Tensor> inv_backward_gpu(int group_id, torch::Tensor grad, to
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"inv_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"inv_backward_kernel"
,
([
&
]
{
inv_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
inv_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -389,7 +389,7 @@ torch::Tensor mul_forward_gpu(int group_id, torch::Tensor X, torch::Tensor Y) {
...
@@ -389,7 +389,7 @@ torch::Tensor mul_forward_gpu(int group_id, torch::Tensor X, torch::Tensor Y) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
Z
=
torch
::
zeros_like
(
X
);
torch
::
Tensor
Z
=
torch
::
zeros_like
(
X
);
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"mul_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"mul_forward_kernel"
,
([
&
]
{
mul_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
mul_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
Y
.
data_ptr
<
scalar_t
>
(),
...
@@ -405,7 +405,7 @@ std::vector<torch::Tensor> mul_backward_gpu(int group_id, torch::Tensor grad, to
...
@@ -405,7 +405,7 @@ std::vector<torch::Tensor> mul_backward_gpu(int group_id, torch::Tensor grad, to
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dY
=
torch
::
zeros
(
Y
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dY
=
torch
::
zeros
(
Y
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"mul_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"mul_backward_kernel"
,
([
&
]
{
mul_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
mul_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -422,7 +422,7 @@ torch::Tensor adj_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
...
@@ -422,7 +422,7 @@ torch::Tensor adj_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adj_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adj_forward_kernel"
,
([
&
]
{
adj_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
adj_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -438,7 +438,7 @@ std::vector<torch::Tensor> adj_backward_gpu(int group_id, torch::Tensor grad, to
...
@@ -438,7 +438,7 @@ std::vector<torch::Tensor> adj_backward_gpu(int group_id, torch::Tensor grad, to
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adj_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adj_backward_kernel"
,
([
&
]
{
adj_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
adj_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -456,7 +456,7 @@ torch::Tensor adjT_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
...
@@ -456,7 +456,7 @@ torch::Tensor adjT_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adjT_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adjT_forward_kernel"
,
([
&
]
{
adjT_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
adjT_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
@@ -472,7 +472,7 @@ std::vector<torch::Tensor> adjT_backward_gpu(int group_id, torch::Tensor grad, t
...
@@ -472,7 +472,7 @@ std::vector<torch::Tensor> adjT_backward_gpu(int group_id, torch::Tensor grad, t
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
torch
::
Tensor
da
=
torch
::
zeros
(
a
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"adjT_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"adjT_backward_kernel"
,
([
&
]
{
adjT_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
adjT_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -491,7 +491,7 @@ torch::Tensor act_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
...
@@ -491,7 +491,7 @@ torch::Tensor act_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act_forward_kernel"
,
([
&
]
{
act_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
act_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
...
@@ -507,7 +507,7 @@ std::vector<torch::Tensor> act_backward_gpu(int group_id, torch::Tensor grad, to
...
@@ -507,7 +507,7 @@ std::vector<torch::Tensor> act_backward_gpu(int group_id, torch::Tensor grad, to
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act_backward_kernel"
,
([
&
]
{
act_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
act_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -524,7 +524,7 @@ torch::Tensor act4_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
...
@@ -524,7 +524,7 @@ torch::Tensor act4_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
torch
::
Tensor
q
=
torch
::
zeros
(
p
.
sizes
(),
p
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act4_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act4_forward_kernel"
,
([
&
]
{
act4_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
act4_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
p
.
data_ptr
<
scalar_t
>
(),
...
@@ -540,7 +540,7 @@ std::vector<torch::Tensor> act4_backward_gpu(int group_id, torch::Tensor grad, t
...
@@ -540,7 +540,7 @@ std::vector<torch::Tensor> act4_backward_gpu(int group_id, torch::Tensor grad, t
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dX
=
torch
::
zeros
(
X
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
torch
::
Tensor
dp
=
torch
::
zeros
(
p
.
sizes
(),
grad
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"act4_backward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"act4_backward_kernel"
,
([
&
]
{
act4_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
act4_backward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
grad
.
data_ptr
<
scalar_t
>
(),
grad
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -558,7 +558,7 @@ torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
...
@@ -558,7 +558,7 @@ torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
T4x4
=
torch
::
zeros
({
X
.
size
(
0
),
4
,
4
},
X
.
options
());
torch
::
Tensor
T4x4
=
torch
::
zeros
({
X
.
size
(
0
),
4
,
4
},
X
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"as_matrix_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"as_matrix_forward_kernel"
,
([
&
]
{
as_matrix_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
as_matrix_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
T4x4
.
data_ptr
<
scalar_t
>
(),
T4x4
.
data_ptr
<
scalar_t
>
(),
...
@@ -573,7 +573,7 @@ torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
...
@@ -573,7 +573,7 @@ torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
P
;
torch
::
Tensor
P
;
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"orthogonal_projector_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"orthogonal_projector_kernel"
,
([
&
]
{
P
=
torch
::
zeros
({
X
.
size
(
0
),
group_t
::
N
,
group_t
::
N
},
X
.
options
());
P
=
torch
::
zeros
({
X
.
size
(
0
),
group_t
::
N
,
group_t
::
N
},
X
.
options
());
orthogonal_projector_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
orthogonal_projector_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
...
@@ -589,7 +589,7 @@ torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a)
...
@@ -589,7 +589,7 @@ torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a)
int
batch_size
=
X
.
size
(
0
);
int
batch_size
=
X
.
size
(
0
);
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
torch
::
Tensor
b
=
torch
::
zeros
(
a
.
sizes
(),
a
.
options
());
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
type
(),
"jleft_forward_kernel"
,
([
&
]
{
DISPATCH_GROUP_AND_FLOATING_TYPES
(
group_id
,
X
.
scalar_
type
(),
"jleft_forward_kernel"
,
([
&
]
{
jleft_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
jleft_forward_kernel
<
group_t
,
scalar_t
><<<
NUM_BLOCKS
(
batch_size
),
NUM_THREADS
>>>
(
X
.
data_ptr
<
scalar_t
>
(),
X
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
a
.
data_ptr
<
scalar_t
>
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment