OpenDAS / torch-sparse · Commits

Commit 92082f98, authored Oct 05, 2018 by rusty1s
Parent: 3c7253aa

    faster coalesce if no value provided

Showing 6 changed files with 79 additions and 8 deletions:

    cuda/unique.cpp                  +14  -0
    cuda/unique_kernel.cu            +32  -0
    setup.py                          +3  -1
    torch_sparse/coalesce.py         +13  -7
    torch_sparse/utils/__init__.py    +0  -0
    torch_sparse/utils/unique.py     +17  -0
cuda/unique.cpp (new file, mode 100644)

#include <torch/torch.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src);

std::tuple<at::Tensor, at::Tensor> unique(at::Tensor src) {
  CHECK_CUDA(src);
  return unique_cuda(src);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("unique", &unique, "Unique (CUDA)");
}
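
The binding exposes a single unique entry point that asserts the input lives on the GPU before dispatching to unique_cuda. Once built (the extension is registered in setup.py below), it is callable from Python; a minimal sketch, assuming the compiled module is importable as unique_cuda:

import torch
import unique_cuda  # compiled from cuda/unique.cpp and cuda/unique_kernel.cu

src = torch.tensor([3, 1, 2, 1, 3], device='cuda')
out, perm = unique_cuda.unique(src)
# out  -> tensor([1, 2, 3], device='cuda:0'): sorted distinct values
# perm -> for each j, an index into src such that src[perm[j]] == out[j]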
cuda/unique_kernel.cu (new file, mode 100644)

#include <ATen/ATen.h>

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, uint8_t *mask,
                                   size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = 1;
    }
  }
}

std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
  at::Tensor perm;
  std::tie(src, perm) = src.sort();
  auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));
  AT_DISPATCH_ALL_TYPES(src.type(), "grid_cuda_kernel", [&] {
    unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
        src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
  });
  src = src.masked_select(mask);
  perm = perm.masked_select(mask);
  return std::make_tuple(src, perm);
}
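
unique_cuda implements unique by sort-and-compact: sort the flattened input while keeping the permutation, launch a grid-stride kernel that flags the first element of every run of equal values, then masked_select both the sorted values and the permutation. The same idea in plain PyTorch on the CPU, as an illustrative sketch (unique_sketch is not part of the commit):

import torch

def unique_sketch(src):
    # Sort the flattened input; perm maps sorted positions back to
    # positions in the original tensor.
    src, perm = src.contiguous().view(-1).sort()
    # Flag the first element of each run of equal values, mirroring
    # what unique_cuda_kernel writes into its mask array.
    mask = torch.zeros(src.numel(), dtype=torch.bool)
    mask[0] = True
    mask[1:] = src[1:] != src[:-1]
    # Compact values and permutation with the mask.
    return src[mask], perm[mask]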
setup.py

@@ -17,7 +17,9 @@ if torch.cuda.is_available():
         'spspmm_cuda',
         ['cuda/spspmm.cpp', 'cuda/spspmm_kernel.cu'],
         extra_link_args=['-lcusparse'],
-    )
+    ),
+    CUDAExtension('unique_cuda',
+                  ['cuda/unique.cpp', 'cuda/unique_kernel.cu'])
 ]

 cmdclass['build_ext'] = BuildExtension
torch_sparse/coalesce.py

 import torch
 import torch_scatter
+from .utils.unique import unique


 def coalesce(index, value, m, n, op='add', fill_value=0):
     """Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
     ...

@@ -23,16 +25,20 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
     row, col = index

-    unique, inv = torch.unique(row * n + col, sorted=True,
-                               return_inverse=True)
+    if value is None:
+        _, perm = unique(row * n + col)
+        index = torch.stack([row[perm], col[perm]], dim=0)
+        return index, value
+
+    uniq, inv = torch.unique(row * n + col, sorted=True, return_inverse=True)

     perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)
-    perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm)
+    perm = inv.new_empty(uniq.size(0)).scatter_(0, inv, perm)

     index = torch.stack([row[perm], col[perm]], dim=0)

-    op = getattr(torch_scatter, 'scatter_{}'.format(op))
-    value = op(value, inv, 0, None, perm.size(0), fill_value)
-    if isinstance(value, tuple):
-        value = value[0]
+    if value is not None:
+        op = getattr(torch_scatter, 'scatter_{}'.format(op))
+        value = op(value, inv, 0, None, perm.size(0), fill_value)
+        if isinstance(value, tuple):
+            value = value[0]

     return index, value
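
With the early return, coalescing an index without values only needs the first-occurrence permutation from the custom unique; torch.unique's return_inverse pass and the scatter step are skipped entirely, which is the speedup the commit message refers to. A usage sketch (illustrative tensors; assumes coalesce is exported at package level):

import torch
from torch_sparse import coalesce

# Three entries in a 2x3 matrix; entry (0, 1) appears twice.
index = torch.tensor([[0, 0, 1],
                      [1, 1, 2]])
value = torch.tensor([1.0, 2.0, 3.0])

# Fast path: no values, so duplicate indices are simply dropped.
index_out, _ = coalesce(index, None, 2, 3)
# index_out -> tensor([[0, 1], [1, 2]])

# General path: duplicate values are reduced, here with op='add'.
index_out, value_out = coalesce(index, value, 2, 3, op='add')
# value_out -> tensor([3., 3.])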
torch_sparse/utils/__init__.py (new file, mode 100644, empty)
torch_sparse/utils/unique.py (new file, mode 100644)

import torch
import numpy as np

if torch.cuda.is_available():
    import unique_cuda


def unique(src):
    src = src.contiguous().view(-1)

    if src.is_cuda:
        out, perm = unique_cuda.unique(src)
    else:
        out, perm = np.unique(src.numpy(), return_index=True)
        out, perm = torch.from_numpy(out), torch.from_numpy(perm)

    return out, perm
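
On both paths the helper returns the sorted distinct values of the flattened input together with indices of representative occurrences: np.unique with return_index=True on the CPU, the sort-and-mask kernel on the GPU (where the sort is not guaranteed stable, so perm may point at any one occurrence of a duplicated value; that is all coalesce needs, since duplicates share the same row and column). A sketch of the contract, assuming the module is importable as torch_sparse.utils.unique:

import torch
from torch_sparse.utils.unique import unique

src = torch.tensor([5, 3, 5, 1])
out, perm = unique(src)
# out  -> tensor([1, 3, 5]): sorted distinct values
# perm -> tensor([3, 1, 0]): indices into src with src[perm] equal to out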