OpenDAS / torch-sparse

Commit 6e87043a, authored Jan 27, 2020 by rusty1s
Parent: b4030755

diag torch script support
Showing 5 changed files with 25 additions and 29 deletions (+25 −29)
cpu/diag.cpp          +6 −7
cuda/diag.cpp         +7 −8
cuda/diag_kernel.cu   +4 −4
torch_sparse/diag.py  +3 −8
torch_sparse/utils.py +5 −2
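Taken together, the changes drop the pybind11 bindings (PYBIND11_MODULE) for the diag extension and instead register non_diag_mask as a TorchScript custom operator via torch::RegisterOperators, with the compiled libraries loaded through torch.ops.load_library. The sketch below shows how the CPU operator might then be called from Python after this commit; the tensor values are made up, and the library path is taken from torch_sparse/utils.py and assumed to point at a built diag_cpu.so.

    import torch

    # Load the compiled extension; after this commit it registers
    # "torch_sparse_cpu::non_diag_mask" via torch::RegisterOperators
    # instead of exposing a pybind11 module.
    torch.ops.load_library('torch_sparse/diag_cpu.so')

    # Illustrative COO indices of a 3x3 sparse matrix.
    row = torch.tensor([0, 1, 2])
    col = torch.tensor([1, 2, 2])

    # Registered ops are exposed as torch.ops.<namespace>.<name>;
    # the arguments follow the C++ signature (row, col, M, N, k).
    mask = torch.ops.torch_sparse_cpu.non_diag_mask(row, col, 3, 3, 0)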
cpu/diag.cpp

-#include <torch/extension.h>
+#include <torch/script.h>
 #include "compat.h"

 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

-at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
-                         int64_t k) {
+torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
+                            int64_t N, int64_t k) {
   CHECK_CPU(row);
   CHECK_CPU(col);
...
@@ -15,7 +15,7 @@ at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
   auto row_data = row.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
+  auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();
   int64_t r, c;
...
@@ -48,6 +48,5 @@ at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
   return mask;
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("non_diag_mask", &non_diag_mask, "Non-Diagonal Mask (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cpu::non_diag_mask", &non_diag_mask);
cuda/diag.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
-                              int64_t N, int64_t k);
+torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
+                                 int64_t M, int64_t N, int64_t k);

-at::Tensor non_diag_mask(at::Tensor row, at::Tensor col, int64_t M, int64_t N,
-                         int64_t k) {
+torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
+                            int64_t N, int64_t k) {
   CHECK_CUDA(row);
   CHECK_CUDA(col);
   return non_diag_mask_cuda(row, col, M, N, k);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("non_diag_mask", &non_diag_mask, "Non-Diagonal Mask (CUDA)");
-}
+static auto registry = torch::RegisterOperators(
+    "torch_sparse_cuda::non_diag_mask", &non_diag_mask);
cuda/diag_kernel.cu

-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
 #include "compat.cuh"
...
@@ -38,15 +38,15 @@ __global__ void non_diag_mask_kernel(const int64_t *row_data,
   }
 }

-at::Tensor non_diag_mask_cuda(at::Tensor row, at::Tensor col, int64_t M,
-                              int64_t N, int64_t k) {
+torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
+                                 int64_t M, int64_t N, int64_t k) {
   int64_t E = row.size(0);
   int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);

   auto row_data = row.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto mask = at::zeros(E + num_diag, row.options().dtype(at::kBool));
+  auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
   auto mask_data = mask.DATA_PTR<bool>();

   auto stream = at::cuda::getCurrentCUDAStream();
...
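The num_diag expression in this kernel (unchanged by the commit) counts the entries on the k-th diagonal of an M x N matrix. A small Python re-statement of the same formula, with made-up sizes, illustrates both branches:

    # Mirrors the C++ expression: k < 0 ? min(M + k, N) : min(M, N - k).
    def num_diag(M, N, k):
        return min(M + k, N) if k < 0 else min(M, N - k)

    assert num_diag(3, 4, 0) == 3   # main diagonal of a 3x4 matrix
    assert num_diag(3, 4, 1) == 3   # min(3, 4 - 1)
    assert num_diag(3, 4, -1) == 2  # min(3 - 1, 4)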
torch_sparse/diag.py

 import torch
-from torch_sparse import diag_cpu
-
-try:
-    from torch_sparse import diag_cuda
-except ImportError:
-    diag_cuda = None
+from .utils import ext


 def remove_diag(src, k=0):
...
@@ -44,8 +39,8 @@ def set_diag(src, values=None, k=0):
     row, col, value = src.coo()
-    func = diag_cuda if row.is_cuda else diag_cpu
-    mask = func.non_diag_mask(row, col, src.size(0), src.size(1), k)
+    mask = ext(row.is_cuda).non_diag_mask(row, col, src.size(0), src.size(1),
+                                          k)
     inv_mask = ~mask

     start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
...
torch_sparse/utils.py

 import torch

 torch.ops.load_library('torch_sparse/convert_cpu.so')
+torch.ops.load_library('torch_sparse/diag_cpu.so')
 try:
     torch.ops.load_library('torch_sparse/convert_cuda.so')
-except OSError:
-    pass
+    torch.ops.load_library('torch_sparse/diag_cuda.so')
+except OSError as e:
+    if torch.cuda.is_available():
+        raise e


 def ext(is_cuda):
     ...
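The body of ext is outside the shown hunks, so its exact definition is not visible here. Given that diag.py now calls ext(row.is_cuda).non_diag_mask(...), a plausible reading is that it returns the matching torch.ops namespace; the helper below is a hypothetical sketch of that behaviour, not the actual implementation.

    import torch

    # Hypothetical sketch: pick the operator namespace by device, so callers
    # can write ext(tensor.is_cuda).non_diag_mask(row, col, M, N, k).
    def ext(is_cuda):
        return torch.ops.torch_sparse_cuda if is_cuda else torch.ops.torch_sparse_cpu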