Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
62891fa0
"...text-generation-inference.git" did not exist on "0d96468ebb1ca0141d7a23b2fdfcef9a7ef7bb81"
Commit
62891fa0
authored
Dec 20, 2019
by
rusty1s
Browse files
arange interleave implementation within PyTorch
parent
e61e3d45
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
70 additions
and
88 deletions
+70
-88
cpu/arange_interleave.cpp
cpu/arange_interleave.cpp
+0
-30
setup.py
setup.py
+21
-16
torch_sparse/index_select.py
torch_sparse/index_select.py
+25
-31
torch_sparse/tensor.py
torch_sparse/tensor.py
+24
-11
No files found.
cpu/arange_interleave.cpp
deleted
100644 → 0
View file @
e61e3d45
#include <torch/extension.h>
#include "compat.h"

// Builds the concatenation of ranges [start[i], start[i] + repeat[i]) for
// every i, i.e. an interleaved arange.  CPU only; `repeat` must be int64.
// NOTE(review): DATA_PTR is presumably a compat macro for data_ptr()/data()
// across PyTorch versions (see compat.h) -- confirm.
at::Tensor arange_interleave(at::Tensor start, at::Tensor repeat) {
  // Total number of output elements.  repeat.sum() is an int64 scalar, so
  // keep the count and all indices in int64_t -- plain `int` would overflow
  // for outputs larger than INT_MAX.
  auto count = repeat.sum().DATA_PTR<int64_t>()[0];
  auto out = at::empty(count, start.options());
  auto repeat_data = repeat.DATA_PTR<int64_t>();

  AT_DISPATCH_ALL_TYPES(start.scalar_type(), "arange_interleave", [&] {
    auto start_data = start.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    int64_t i = 0;  // write cursor into the flat output
    for (int64_t start_idx = 0; start_idx < start.size(0); start_idx++) {
      scalar_t init = start_data[start_idx];
      // Emit init, init + 1, ..., init + repeat[start_idx] - 1.
      for (scalar_t rep_idx = 0; rep_idx < repeat_data[start_idx];
           rep_idx++) {
        out_data[i] = init + rep_idx;
        i++;
      }
    }
  });

  return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("arange_interleave", &arange_interleave, "Arange Interleave (CPU)");
}
setup.py
View file @
62891fa0
...
@@ -12,11 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
...
@@ -12,11 +12,11 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
ext_modules
=
[
ext_modules
=
[
CppExtension
(
'torch_sparse.arange_interleave_cpu'
,
CppExtension
(
[
'cpu/arange_interleave.cpp'
]
,
'torch_sparse.spspmm_cpu'
,
extra_compile_args
=
extra_compile_args
)
,
[
'cpu/spspmm.cpp'
]
,
CppExtension
(
'torch_sparse.spspmm_cpu'
,
[
'cpu/spspmm.cpp'
]
,
extra_compile_args
=
extra_compile_args
,
extra_compile_args
=
extra_compile_args
),
),
]
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
...
@@ -33,17 +33,22 @@ if CUDA_HOME is not None and GPU:
...
@@ -33,17 +33,22 @@ if CUDA_HOME is not None and GPU:
extra_link_args
=
[
'-lcusparse'
,
'-l'
,
'cusparse'
]
extra_link_args
=
[
'-lcusparse'
,
'-l'
,
'cusparse'
]
ext_modules
+=
[
ext_modules
+=
[
CUDAExtension
(
'torch_sparse.spmm_cuda'
,
CUDAExtension
(
[
'cuda/spmm.cpp'
,
'cuda/spmm_kernel.cu'
],
'torch_sparse.spmm_cuda'
,
extra_link_args
=
extra_link_args
,
[
'cuda/spmm.cpp'
,
'cuda/spmm_kernel.cu'
],
extra_compile_args
=
extra_compile_args
),
extra_compile_args
=
extra_compile_args
,
CUDAExtension
(
'torch_sparse.spspmm_cuda'
,
),
[
'cuda/spspmm.cpp'
,
'cuda/spspmm_kernel.cu'
],
CUDAExtension
(
extra_link_args
=
extra_link_args
,
'torch_sparse.spspmm_cuda'
,
extra_compile_args
=
extra_compile_args
),
[
'cuda/spspmm.cpp'
,
'cuda/spspmm_kernel.cu'
],
CUDAExtension
(
'torch_sparse.unique_cuda'
,
extra_link_args
=
extra_link_args
,
[
'cuda/unique.cpp'
,
'cuda/unique_kernel.cu'
],
extra_compile_args
=
extra_compile_args
,
extra_compile_args
=
extra_compile_args
),
),
CUDAExtension
(
'torch_sparse.unique_cuda'
,
[
'cuda/unique.cpp'
,
'cuda/unique_kernel.cu'
],
extra_compile_args
=
extra_compile_args
,
),
]
]
__version__
=
'0.4.3'
__version__
=
'0.4.3'
...
...
torch_sparse/index_select.py
View file @
62891fa0
import
torch
import
torch
from
torch_sparse.storage
import
get_layout
from
torch_sparse.storage
import
get_layout
import
torch_sparse.arange_interleave_cpu
as
arange_interleave_cpu
def arange_interleave(start, repeat):
    """Return the concatenation of ``arange(start[i], start[i] + repeat[i])``
    for every ``i``.

    Example: ``start=[0, 10]``, ``repeat=[3, 2]`` -> ``[0, 1, 2, 10, 11]``.

    Args:
        start: 1-dim tensor of range starts (any numeric dtype).
        repeat: 1-dim int64 tensor of per-range lengths, same length as
            ``start``.

    Returns:
        1-dim tensor with ``repeat.sum()`` elements, same dtype/device as
        ``start``.
    """
    assert start.device == repeat.device
    assert repeat.dtype == torch.long
    assert start.dim() == 1
    assert repeat.dim() == 1
    assert start.numel() == repeat.numel()

    # Pure-PyTorch implementation: works on both CPU and CUDA, no compiled
    # extension needed.  Each output slot gets its range's start value plus
    # its position within that range.
    out = start.repeat_interleave(repeat)
    # Exclusive prefix sum = global index at which each range begins.
    # (cumsum - repeat, rather than cat([0], cumsum[:-1]), so that empty
    # inputs are handled without a stray leading zero.)
    run_starts = repeat.cumsum(0) - repeat
    offset = torch.arange(out.numel(), device=out.device)
    offset = offset - run_starts.repeat_interleave(repeat)
    # Cast offsets to the output dtype so integer promotion does not widen
    # e.g. an int32 `start` to int64.
    return out + offset.to(out.dtype)
def
index_select
(
src
,
dim
,
idx
):
def
index_select
(
src
,
dim
,
idx
):
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
dim
=
src
.
dim
()
+
dim
if
dim
<
0
else
dim
assert
idx
.
dim
()
==
1
assert
idx
.
dim
()
==
1
idx
=
idx
.
to
(
src
.
device
)
if
dim
==
0
:
if
dim
==
0
:
(
_
,
col
),
value
=
src
.
coo
()
(
row
,
col
),
value
=
src
.
coo
()
rowcount
=
src
.
storage
.
rowcount
rowcount
=
src
.
storage
.
rowcount
rowptr
=
src
.
storage
.
rowptr
old_
rowptr
=
src
.
storage
.
rowptr
rowcount
=
rowcount
[
idx
]
rowcount
=
rowcount
[
idx
]
tmp
=
torch
.
arange
(
rowcount
.
size
(
0
),
device
=
rowcount
.
device
)
tmp
=
torch
.
arange
(
rowcount
.
size
(
0
),
device
=
rowcount
.
device
)
row
=
tmp
.
repeat_interleave
(
rowcount
)
row
=
tmp
.
repeat_interleave
(
rowcount
)
perm
=
arange_interleave
(
rowptr
[
idx
],
rowcount
)
# Creates an "arange interleave" tensor of col indices.
rowptr
=
torch
.
cat
([
row
.
new_zeros
(
1
),
rowcount
.
cumsum
(
0
)],
dim
=
0
)
perm
=
torch
.
arange
(
row
.
size
(
0
),
device
=
row
.
device
)
perm
+=
(
old_rowptr
[
idx
]
-
rowptr
[:
-
1
])[
row
]
col
=
col
[
perm
]
col
=
col
[
perm
]
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
...
@@ -38,17 +30,23 @@ def index_select(src, dim, idx):
...
@@ -38,17 +30,23 @@ def index_select(src, dim, idx):
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
sparse_size
=
torch
.
Size
([
rowcount
.
size
(
0
),
src
.
sparse_size
(
1
)])
storage
=
src
.
storage
.
__class__
(
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
index
,
value
,
sparse_size
,
rowcount
=
rowcount
,
is_sorted
=
True
)
rowcount
=
rowcount
,
rowptr
=
rowptr
,
is_sorted
=
True
)
elif
dim
==
1
:
elif
dim
==
1
:
colptr
,
row
,
value
=
src
.
csc
()
old_
colptr
,
row
,
value
=
src
.
csc
()
colcount
=
src
.
storage
.
colcount
colcount
=
src
.
storage
.
colcount
colcount
=
colcount
[
idx
]
colcount
=
colcount
[
idx
]
tmp
=
torch
.
arange
(
colcount
.
size
(
0
),
device
=
row
.
device
)
tmp
=
torch
.
arange
(
colcount
.
size
(
0
),
device
=
row
.
device
)
col
=
tmp
.
repeat_interleave
(
colcount
)
col
=
tmp
.
repeat_interleave
(
colcount
)
perm
=
arange_interleave
(
colptr
[
idx
],
colcount
)
# Creates an "arange interleave" tensor of row indices.
colptr
=
torch
.
cat
([
col
.
new_zeros
(
1
),
colcount
.
cumsum
(
0
)],
dim
=
0
)
perm
=
torch
.
arange
(
col
.
size
(
0
),
device
=
col
.
device
)
perm
+=
(
old_colptr
[
idx
]
-
colptr
[:
-
1
])[
col
]
row
=
row
[
perm
]
row
=
row
[
perm
]
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)[:,
csc2csr
]
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)[:,
csc2csr
]
...
@@ -58,17 +56,13 @@ def index_select(src, dim, idx):
...
@@ -58,17 +56,13 @@ def index_select(src, dim, idx):
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
colcount
.
size
(
0
)])
storage
=
src
.
storage
.
__class__
(
storage
=
src
.
storage
.
__class__
(
index
,
value
,
sparse_size
,
index
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
value
,
is_sorted
=
True
)
sparse_size
,
colcount
=
colcount
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
else
:
else
:
storage
=
src
.
storage
.
apply_value
(
lambda
x
:
x
.
index_select
(
storage
=
src
.
storage
.
apply_value
(
dim
-
1
,
idx
))
lambda
x
:
x
.
index_select
(
dim
-
1
,
idx
))
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
...
@@ -86,7 +80,7 @@ def index_select_nnz(src, idx, layout=None):
...
@@ -86,7 +80,7 @@ def index_select_nnz(src, idx, layout=None):
value
=
value
[
idx
]
value
=
value
[
idx
]
# There is no other information we can maintain...
# There is no other information we can maintain...
storage
=
src
.
storage
.
__class__
(
storage
=
src
.
storage
.
__class__
(
index
,
value
,
src
.
sparse_size
(),
index
,
value
,
src
.
sparse_size
(),
is_sorted
=
True
)
is_sorted
=
True
)
return
src
.
from_storage
(
storage
)
return
src
.
from_storage
(
storage
)
torch_sparse/tensor.py
View file @
62891fa0
...
@@ -487,23 +487,36 @@ if __name__ == '__main__':
...
@@ -487,23 +487,36 @@ if __name__ == '__main__':
import
time
# noqa
import
time
# noqa
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
# device = 'cpu'
# dataset = Reddit('/tmp/Reddit')
# dataset = Reddit('/tmp/Reddit')
dataset
=
Planetoid
(
'/tmp/
Cora'
,
'Cora
'
)
dataset
=
Planetoid
(
'/tmp/
PubMed'
,
'PubMed
'
)
data
=
dataset
[
0
].
to
(
device
)
data
=
dataset
[
0
].
to
(
device
)
value
=
torch
.
randn
(
data
.
num_edges
,
10
)
# value = torch.randn(data.num_edges, 10)
mat
=
SparseTensor
(
data
.
edge_index
,
value
)
mat
=
SparseTensor
(
data
.
edge_index
)
perm
=
torch
.
arange
(
data
.
num_nodes
)
perm
=
torch
.
randperm
(
data
.
num_nodes
)
index
=
torch
.
tensor
([
for
_
in
range
(
10
):
[
0
,
1
,
1
,
2
,
2
],
x
=
torch
.
randn
(
1000
,
1000
,
device
=
device
).
sum
()
[
1
,
2
,
2
,
2
,
3
],
])
value
=
torch
.
tensor
([
1
,
2
,
3
,
4
,
5
])
mat
=
SparseTensor
(
index
,
value
)
torch
.
cuda
.
synchronize
()
print
(
mat
)
t
=
time
.
perf_counter
()
print
(
mat
.
coalesce
())
for
_
in
range
(
100
):
mat
[
perm
]
torch
.
cuda
.
synchronize
()
print
(
time
.
perf_counter
()
-
t
)
# index = torch.tensor([
# [0, 1, 1, 2, 2],
# [1, 2, 2, 2, 3],
# ])
# value = torch.tensor([1, 2, 3, 4, 5])
# mat = SparseTensor(index, value)
# print(mat)
# print(mat.coalesce())
# index = torch.tensor([0, 1, 2])
# index = torch.tensor([0, 1, 2])
# mask = torch.zeros(data.num_nodes, dtype=torch.bool)
# mask = torch.zeros(data.num_nodes, dtype=torch.bool)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment