OpenDAS / torch-sparse · Commits

Commit 918b1163, authored Jan 25, 2020 by rusty1s
Parent: 2ae73b17

conversion utilities

Showing 7 changed files with 141 additions and 61 deletions:

  cpu/convert.cpp          +51   -0
  cuda/convert.cpp         +21   -0
  cuda/convert_kernel.cu   +56   -0
  cuda/rowptr.cpp           +0  -14
  cuda/rowptr_kernel.cu     +0  -37
  torch_sparse/storage.py  +12   -9
  torch_sparse/tensor.py    +1   -1
cpu/rowptr.cpp → cpu/convert.cpp

@@ -4,20 +4,19 @@
 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
 
-at::Tensor rowptr(at::Tensor row, int64_t M) {
-  CHECK_CPU(row);
-  AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");
+at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
+  CHECK_CPU(ind);
 
-  auto out = at::empty(M + 1, row.options());
-  auto row_data = row.DATA_PTR<int64_t>();
+  auto out = at::empty(M + 1, ind.options());
+  auto ind_data = ind.DATA_PTR<int64_t>();
   auto out_data = out.DATA_PTR<int64_t>();
 
-  int64_t numel = row.numel(), idx = row_data[0], next_idx;
+  int64_t numel = ind.numel(), idx = ind_data[0], next_idx;
   for (int64_t i = 0; i <= idx; i++)
     out_data[i] = 0;
 
   for (int64_t i = 0; i < numel - 1; i++) {
-    next_idx = row_data[i + 1];
+    next_idx = ind_data[i + 1];
     for (int64_t j = idx; j < next_idx; j++)
       out_data[j + 1] = i + 1;
     idx = next_idx;

@@ -29,6 +28,24 @@ at::Tensor rowptr(at::Tensor row, int64_t M) {
   return out;
 }
 
+at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
+  CHECK_CPU(ptr);
+
+  auto out = at::empty(E, ptr.options());
+  auto ptr_data = ptr.DATA_PTR<int64_t>();
+  auto out_data = out.DATA_PTR<int64_t>();
+
+  int64_t idx = ptr_data[0], next_idx;
+  for (int64_t i = 0; i < ptr.numel() - 1; i++) {
+    next_idx = ptr_data[i + 1];
+    for (int64_t e = idx; e < next_idx; e++)
+      out_data[e] = i;
+    idx = next_idx;
+  }
+
+  return out;
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("rowptr", &rowptr, "Rowptr (CPU)");
+  m.def("ind2ptr", &ind2ptr, "Ind2Ptr (CPU)");
+  m.def("ptr2ind", &ptr2ind, "Ptr2Ind (CPU)");
 }
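For reference, a minimal pure-PyTorch sketch (not part of the commit; the function names are illustrative) of what the two new CPU routines compute: ind2ptr compresses a sorted index vector into a CSR-style pointer vector of length M + 1, and ptr2ind expands a pointer vector back into indices.

import torch

def ind2ptr_reference(ind, M):
    # ptr[k] counts how many indices are strictly smaller than k.
    ptr = torch.zeros(M + 1, dtype=torch.long, device=ind.device)
    ptr[1:] = torch.cumsum(torch.bincount(ind, minlength=M), dim=0)
    return ptr

def ptr2ind_reference(ptr, E):
    # E (= ptr[-1]) only mirrors the C++ signature; each position i is
    # repeated by its segment length ptr[i + 1] - ptr[i].
    counts = ptr[1:] - ptr[:-1]
    return torch.repeat_interleave(torch.arange(counts.numel(), device=ptr.device), counts)

row = torch.tensor([0, 0, 1, 3])
assert ind2ptr_reference(row, 4).tolist() == [0, 2, 3, 3, 4]
assert ptr2ind_reference(torch.tensor([0, 2, 3, 3, 4]), 4).tolist() == [0, 0, 1, 3]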
cuda/convert.cpp (new file, 0 → 100644)

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor ind2ptr_cuda(at::Tensor ind, int64_t M);
at::Tensor ptr2ind_cuda(at::Tensor ptr, int64_t E);

at::Tensor ind2ptr(at::Tensor ind, int64_t M) {
  CHECK_CUDA(ind);
  return ind2ptr_cuda(ind, M);
}

at::Tensor ptr2ind(at::Tensor ptr, int64_t E) {
  CHECK_CUDA(ptr);
  return ptr2ind_cuda(ptr, E);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ind2ptr", &ind2ptr, "Ind2Ptr (CUDA)");
  m.def("ptr2ind", &ptr2ind, "Ptr2Ind (CUDA)");
}
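The CUDA extension exposes the same ind2ptr/ptr2ind names as the CPU one, so Python callers can pick the module by device. A sketch of that dispatch pattern, assuming both extensions build under the module names imported in torch_sparse/storage.py below:

import torch
from torch_sparse import convert_cpu

try:
    from torch_sparse import convert_cuda
except ImportError:  # built without CUDA support
    convert_cuda = None

def ind2ptr(ind, M):
    # Route the call to the extension matching the tensor's device.
    func = convert_cuda if ind.is_cuda else convert_cpu
    return func.ind2ptr(ind, M)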
cuda/convert_kernel.cu (new file, 0 → 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "compat.cuh"

#define THREADS 1024

__global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
                               int64_t M, int64_t numel) {
  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  if (thread_idx == 0) {
    for (int64_t i = 0; i <= ind_data[0]; i++)
      out_data[i] = 0;
  } else if (thread_idx < numel) {
    for (int64_t i = ind_data[thread_idx - 1]; i < ind_data[thread_idx]; i++)
      out_data[i + 1] = thread_idx;
  } else if (thread_idx == numel) {
    for (int64_t i = ind_data[numel - 1] + 1; i < M + 1; i++)
      out_data[i] = numel;
  }
}

at::Tensor ind2ptr_cuda(at::Tensor ind, int64_t M) {
  auto out = at::empty(M + 1, ind.options());
  auto ind_data = ind.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  ind2ptr_kernel<<<(ind.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
                   stream>>>(ind_data, out_data, M, ind.numel());

  return out;
}

__global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
                               int64_t E, int64_t numel) {
  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  if (thread_idx < numel) {
    int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1];
    for (int64_t i = idx; i < next_idx; i++) {
      out_data[i] = thread_idx;
    }
  }
}

at::Tensor ptr2ind_cuda(at::Tensor ptr, int64_t E) {
  auto out = at::empty(E, ptr.options());
  auto ptr_data = ptr.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  ptr2ind_kernel<<<(ptr.numel() + THREADS - 1) / THREADS, THREADS, 0,
                   stream>>>(ptr_data, out_data, E, ptr.numel());

  return out;
}
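To make the parallelization strategy explicit, here is a sequential Python sketch (assumed, not part of the commit) of the work assignment in ind2ptr_kernel: thread t fills only the pointer slots lying between ind[t - 1] and ind[t], thread 0 handles the leading zeros, and thread numel pads the tail, so no two threads ever write the same output element and no synchronization is needed.

def ind2ptr_sketch(ind, M):
    out = [0] * (M + 1)
    numel = len(ind)
    for t in range(numel + 1):  # one iteration per CUDA thread
        if t == 0:
            for i in range(ind[0] + 1):
                out[i] = 0
        elif t < numel:
            for i in range(ind[t - 1], ind[t]):
                out[i + 1] = t
        else:  # t == numel: pad the tail of the pointer vector
            for i in range(ind[numel - 1] + 1, M + 1):
                out[i] = numel
    return out

# ind2ptr_sketch([0, 0, 1, 3], 4) == [0, 2, 3, 3, 4]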
cuda/rowptr.cpp (deleted, 100644 → 0)

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor rowptr_cuda(at::Tensor row, int64_t M);

at::Tensor rowptr(at::Tensor row, int64_t M) {
  CHECK_CUDA(row);
  return rowptr_cuda(row, M);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rowptr", &rowptr, "Rowptr (CUDA)");
}
cuda/rowptr_kernel.cu (deleted, 100644 → 0)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "compat.cuh"

#define THREADS 1024

__global__ void rowptr_kernel(const int64_t *row_data, int64_t *out_data,
                              int64_t M, int64_t numel) {
  int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;

  if (thread_idx == 0) {
    for (int64_t i = 0; i <= row_data[0]; i++)
      out_data[i] = 0;
  } else if (thread_idx < numel) {
    for (int64_t i = row_data[thread_idx - 1]; i < row_data[thread_idx]; i++)
      out_data[i + 1] = thread_idx;
  } else if (thread_idx == numel) {
    for (int64_t i = row_data[numel - 1] + 1; i < M + 1; i++)
      out_data[i] = numel;
  }
}

at::Tensor rowptr_cuda(at::Tensor row, int64_t M) {
  AT_ASSERTM(row.dim() == 1, "Row needs to be one-dimensional");

  auto out = at::empty(M + 1, row.options());
  auto row_data = row.DATA_PTR<int64_t>();
  auto out_data = out.DATA_PTR<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  rowptr_kernel<<<(row.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
                  stream>>>(row_data, out_data, M, row.numel());

  return out;
}
torch_sparse/storage.py

@@ -3,12 +3,12 @@ import warnings
 import torch
 from torch_scatter import segment_csr, scatter_add
-from torch_sparse import rowptr_cpu
+from torch_sparse import convert_cpu
 
 try:
-    from torch_sparse import rowptr_cuda
+    from torch_sparse import convert_cuda
 except ImportError:
-    rowptr_cuda = None
+    convert_cuda = None
 
 __cache__ = {'enabled': True}

@@ -159,8 +159,8 @@ class SparseStorage(object):
     @property
     def row(self):
         if self._row is None:
-            # TODO
-            pass
+            func = convert_cuda if self.rowptr.is_cuda else convert_cpu
+            self._row = func.ptr2ind(self.rowptr, self.nnz())
         return self._row
 
     def has_rowptr(self):

@@ -169,8 +169,8 @@ class SparseStorage(object):
     @property
     def rowptr(self):
         if self._rowptr is None:
-            func = rowptr_cuda if self.row.is_cuda else rowptr_cpu
-            self._rowptr = func.rowptr(self.row, self.sparse_size[0])
+            func = convert_cuda if self.row.is_cuda else convert_cpu
+            self._rowptr = func.ind2ptr(self.row, self.sparse_size[0])
         return self._rowptr
 
     @property

@@ -258,6 +258,9 @@ class SparseStorage(object):
                              colcount=colcount, csr2csc=self._csr2csc,
                              csc2csr=self._csc2csr, is_sorted=True)
 
+    def nnz(self):
+        return self.col.numel()
+
     def has_rowcount(self):
         return self._rowcount is not None

@@ -271,8 +274,8 @@ class SparseStorage(object):
     @cached_property
     def colptr(self):
         if self.has_csr2csc():
-            func = rowptr_cuda if self.col.is_cuda else rowptr_cpu
-            return func.rowptr(self.col[self.csr2csc], self.sparse_size[1])
+            func = convert_cuda if self.col.is_cuda else convert_cpu
+            return func.ind2ptr(self.col[self.csr2csc], self.sparse_size[1])
         else:
             colptr = self.col.new_zeros(self.sparse_size[1] + 1)
             torch.cumsum(self.colcount, dim=0, out=colptr[1:])
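Assuming the extensions build under the module names imported above, the two conversions round-trip, which is what lets row and rowptr be derived lazily from one another; a small check:

import torch
from torch_sparse import convert_cpu

row = torch.tensor([0, 0, 1, 3])          # sorted COO row indices, M = 4 rows
rowptr = convert_cpu.ind2ptr(row, 4)      # expected: [0, 2, 3, 3, 4]
assert convert_cpu.ptr2ind(rowptr, row.numel()).equal(row)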
torch_sparse/tensor.py

@@ -178,7 +178,7 @@ class SparseTensor(object):
         return self.size()
 
     def nnz(self):
-        return self.storage.index.size(1)
+        return self.storage.nnz()
 
     def density(self):
         return self.nnz() / (self.sparse_size(0) * self.sparse_size(1))