OpenDAS / torch-sparse / Commits / e696cfd6
"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "9ccebe5b50823e638322e84c65a3526ea5d684b5"
Commit e696cfd6, authored Jan 13, 2020 by rusty1s

reduce op

Parent: 5dc4080c
Showing 2 changed files with 70 additions and 40 deletions.
torch_sparse/reduce.py   +26  -6
torch_sparse/tensor.py   +44  -34
torch_sparse/reduce.py

...
@@ -14,7 +14,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
         return torch.tensor(value, device=src.device)

     dims = [dim] if isinstance(dim, int) else sorted(list(dim))
-    assert dim[-1] < src.dim()
+    assert dims[-1] < src.dim()
     rowptr, col, value = src.csr()
...
@@ -30,7 +30,7 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
         value = src.nnz() if reduce == 'add' else 1
         return torch.tensor(value, device=src.device)

-    if len(dense_dims) > 0 and len(sparse_dims) == 0:
+    if len(dense_dims) > 0 and len(sparse_dims) == 0:  # src.has_value()
         func = getattr(torch, 'sum' if reduce == 'add' else reduce)
         dense_dims = dense_dims[0] if len(dense_dims) == 1 else dense_dims
         value = func(value, dim=dense_dims)
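A note on the tuple handling around these `func(value, dim=dense_dims)` calls (and the `out[0] if ... isinstance(out, tuple) else out` lines in the next hunk): `getattr(torch, reduce)` resolves to `torch.min`/`torch.max` for those reductions, and given a `dim` argument these return a `(values, indices)` namedtuple, whereas `torch.sum` and `torch.mean` return a plain tensor. A minimal sketch of that behavior (the shapes here are illustrative, not from the commit):

    import torch

    value = torch.randn(5, 8)        # e.g. one 8-dim value per nonzero
    out = torch.max(value, dim=-1)   # torch.return_types.max, a tuple subclass
    assert isinstance(out, tuple)
    out = out[0] if isinstance(out, tuple) else out   # keep the values only

    out = torch.sum(value, dim=-1)   # plain tensor, nothing to unpack
    assert not isinstance(out, tuple)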
...
@@ -44,23 +44,43 @@ def __reduce__(src, dim=None, reduce='add', deterministic=False):
         value = func(value, dim=dense_dims)
         value = value[0] if isinstance(value, tuple) else value

-    if sparse_dims[0] == 0:
+    if sparse_dims[0] == 1 and src.has_value():
         out = segment_csr(value, rowptr)
         out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
         return out

-    if sparse_dims[0] == 1 and (src.storage._csr2csc or deterministic):
+    if sparse_dims[0] == 1 and not src.has_value():
+        assert reduce in ['add', 'mean', 'min', 'max']
+        if reduce == 'add':
+            return src.storage.rowcount.to(torch.get_default_dtype())
+        elif reduce == 'min' or 'max':
+            return torch.ones(src.size(0), device=src.device), None
+        else:
+            return torch.ones(src.size(0), device=src.device)
+
+    deterministic = src.storage._csr2csc is not None or deterministic
+    if sparse_dims[0] == 0 and deterministic and src.has_value():
         csr2csc, colptr = src.storage.csr2csc, src.storage.colptr
         out = segment_csr(value[csr2csc], colptr)
         out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
         return out

-    if sparse_dims[0] == 1:
+    if sparse_dims[0] == 0 and src.has_value():
         func = getattr(torch_scatter, f'scatter_{reduce}')
-        out = func(value, col, dim=0, dim_size=src.sparse_size(0))
+        out = func(value, col, dim=0, dim_size=src.sparse_size(1))
         out = out[0] if len(dense_dims) > 0 and isinstance(out, tuple) else out
         return out

+    if sparse_dims[0] == 0 and not src.has_value():
+        assert reduce in ['add', 'mean', 'min', 'max']
+        if reduce == 'add':
+            return src.storage.colcount.to(torch.get_default_dtype())
+        elif reduce == 'min' or 'max':
+            return torch.ones(src.size(1), device=src.device), None
+        else:
+            return torch.ones(src.size(1), device=src.device)
+

 def sum(src, dim=None, deterministic=False):
     return __reduce__(src, dim, reduce='add', deterministic=deterministic)
...
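For orientation, here is a small pure-PyTorch sketch of what the two sparse reduction paths above compute on a CSR matrix: reducing over dim=1 segments `value` by `rowptr` (the `segment_csr(value, rowptr)` path), while reducing over dim=0 scatters `value` by column index (the `scatter_*(value, col, dim=0, dim_size=...)` path). The 3 x 4 toy matrix is made up for illustration, not taken from the commit:

    import torch

    # 3 x 4 sparse matrix in CSR form; rowptr delimits each row's nonzeros.
    rowptr = torch.tensor([0, 2, 2, 5])   # row 0: 2 nnz, row 1: 0, row 2: 3
    col = torch.tensor([0, 3, 1, 2, 3])
    value = torch.tensor([1., 2., 3., 4., 5.])

    # dim=1 (one output per row) -- what segment_csr(value, rowptr) computes:
    row_sum = torch.zeros(3)
    for r in range(3):
        row_sum[r] = value[rowptr[r]:rowptr[r + 1]].sum()
    print(row_sum)   # tensor([3., 0., 12.])

    # dim=0 (one output per column) -- what the scatter path computes:
    col_sum = torch.zeros(4).scatter_add_(0, col, value)
    print(col_sum)   # tensor([1., 3., 4., 7.])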
torch_sparse/tensor.py

...
@@ -10,13 +10,14 @@ from torch_sparse.narrow import narrow
 from torch_sparse.select import select
 from torch_sparse.index_select import index_select, index_select_nnz
 from torch_sparse.masked_select import masked_select, masked_select_nnz
+import torch_sparse.reduce
 from torch_sparse.add import add, add_nnz


 class SparseTensor(object):
     def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
         self.storage = SparseStorage(
             index, value, sparse_size,
             is_sorted=is_sorted)

     @classmethod
     def from_storage(self, storage):
...
@@ -37,8 +38,8 @@ class SparseTensor(object):
     @classmethod
     def from_torch_sparse_coo_tensor(self, mat, is_sorted=False):
         return SparseTensor(
             mat._indices(), mat._values(),
             mat.size()[:2], is_sorted=is_sorted)

     @classmethod
     def from_scipy(self, mat):
...
@@ -55,8 +56,8 @@ class SparseTensor(object):
         value = torch.from_numpy(mat.data)
         size = mat.shape
         storage = SparseStorage(
             index, value, size, rowptr=rowptr,
             colptr=colptr, is_sorted=True)
         return SparseTensor.from_storage(storage)
...
@@ -193,8 +194,8 @@ class SparseTensor(object):
         return self.from_storage(self.storage.apply(lambda x: x.cpu()))

     def cuda(self, device=None, non_blocking=False, **kwargs):
-        storage = self.storage.apply(lambda x: x.cuda(device, non_blocking, **
-                                                      kwargs))
+        storage = self.storage.apply(
+            lambda x: x.cuda(device, non_blocking, **kwargs))
         return self.from_storage(storage)

     @property
...
@@ -216,8 +217,8 @@ class SparseTensor(object):
         if dtype == self.dtype:
             return self
-        storage = self.storage.apply_value(lambda x: x.type(
-            dtype, non_blocking, **kwargs))
+        storage = self.storage.apply_value(
+            lambda x: x.type(dtype, non_blocking, **kwargs))
         return self.from_storage(storage)
...
@@ -286,12 +287,9 @@ class SparseTensor(object):
     def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
         index, value = self.coo()
         return torch.sparse_coo_tensor(
             index,
             value if self.has_value() else torch.ones(
                 self.nnz(), dtype=dtype, device=self.device),
             self.size(),
             device=self.device, requires_grad=requires_grad)

     def to_scipy(self, dtype=None, layout=None):
         assert self.dim() == 2
@@ -392,6 +390,10 @@ SparseTensor.index_select = index_select
...
@@ -392,6 +390,10 @@ SparseTensor.index_select = index_select
SparseTensor
.
index_select_nnz
=
index_select_nnz
SparseTensor
.
index_select_nnz
=
index_select_nnz
SparseTensor
.
masked_select
=
masked_select
SparseTensor
.
masked_select
=
masked_select
SparseTensor
.
masked_select_nnz
=
masked_select_nnz
SparseTensor
.
masked_select_nnz
=
masked_select_nnz
SparseTensor
.
sum
=
torch_sparse
.
reduce
.
sum
SparseTensor
.
mean
=
torch_sparse
.
reduce
.
mean
SparseTensor
.
min
=
torch_sparse
.
reduce
.
min
SparseTensor
.
max
=
torch_sparse
.
reduce
.
max
SparseTensor
.
add
=
add
SparseTensor
.
add
=
add
SparseTensor
.
add_nnz
=
add_nnz
SparseTensor
.
add_nnz
=
add_nnz
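With these assignments in place, the new reductions become methods on `SparseTensor`. A hedged usage sketch (the 2 x 3 pattern, the expected output, and the package-level import are illustrative assumptions, not part of the commit):

    import torch
    from torch_sparse import SparseTensor   # assuming the package-level export

    index = torch.tensor([[0, 0, 1], [0, 1, 2]])   # 2 x 3 pattern, 3 nonzeros
    value = torch.tensor([1., 2., 3.])
    mat = SparseTensor(index, value)

    out = mat.sum(dim=1)   # per-row sums via the rowptr/segment_csr path
    print(out)             # expected: tensor([3., 3.])

`mean`, `min`, and `max` follow the same calling convention through `__reduce__`.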
...
@@ -461,30 +463,38 @@ if __name__ == '__main__':
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
     # device = 'cpu'

-    dataset = Reddit('/tmp/Reddit')
-    # dataset = Planetoid('/tmp/PubMed', 'PubMed')
+    # dataset = Reddit('/tmp/Reddit')
+    dataset = Planetoid('/tmp/PubMed', 'PubMed')
     data = dataset[0].to(device)

-    # value = torch.randn(data.num_edges, 10)
-    mat = SparseTensor(data.edge_index)
-
-    perm = torch.arange(data.num_nodes)
-    perm = torch.randperm(data.num_nodes)
-
-    mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
-    mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
-    add(mat1, mat2)
-    # print(mat2)
-
-    raise NotImplementedError
-
-    for _ in range(10):
-        x = torch.randn(1000, 1000, device=device).sum()
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        mat[perm]
-    torch.cuda.synchronize()
-    print(time.perf_counter() - t)
+    value = torch.randn((data.num_edges, 8), device=device)
+    mat = SparseTensor(data.edge_index, value)
+    print(mat)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    out = mat.sum(dim=1)
+    torch.cuda.synchronize()
+    print(time.perf_counter() - t)
+    print(out.size())
+
+    # perm = torch.arange(data.num_nodes)
+    # perm = torch.randperm(data.num_nodes)
+    # mat1 = SparseTensor(torch.tensor([[0, 1], [0, 1]]))
+    # mat2 = SparseTensor(torch.tensor([[0, 0, 1], [0, 1, 0]]))
+    # add(mat1, mat2)
+    # # print(mat2)
+    # raise NotImplementedError
+    # for _ in range(10):
+    #     x = torch.randn(1000, 1000, device=device).sum()
+    # torch.cuda.synchronize()
+    # t = time.perf_counter()
+    # for _ in range(100):
+    #     mat[perm]
+    # torch.cuda.synchronize()
+    # print(time.perf_counter() - t)

     # index = torch.tensor([
     #     [0, 1, 1, 2, 2],
...
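The benchmark in `__main__` follows the usual CUDA timing pattern: warm-up work first, then `torch.cuda.synchronize()` on both sides of the timed region, because kernel launches are asynchronous and `time.perf_counter()` alone would mostly measure launch overhead. A minimal standalone sketch of the same pattern (not from the commit):

    import time
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(1000, 1000, device=device)

    for _ in range(10):              # warm-up: absorb one-time CUDA init costs
        x.sum()

    if torch.cuda.is_available():
        torch.cuda.synchronize()     # drain pending kernels before starting the clock
    t = time.perf_counter()
    out = x.sum()
    if torch.cuda.is_available():
        torch.cuda.synchronize()     # wait for the timed kernel to actually finish
    print(time.perf_counter() - t)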