Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
3c6dbfa1
Commit
3c6dbfa1
authored
Jan 31, 2020
by
rusty1s
Browse files
reduction
parent
2515ce6d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
95 additions
and
105 deletions
+95
-105
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-0
torch_sparse/add.py
torch_sparse/add.py
+2
-4
torch_sparse/index_select.py
torch_sparse/index_select.py
+2
-4
torch_sparse/mul.py
torch_sparse/mul.py
+2
-4
torch_sparse/reduce.py
torch_sparse/reduce.py
+85
-83
torch_sparse/storage.py
torch_sparse/storage.py
+3
-5
torch_sparse/tensor.py
torch_sparse/tensor.py
+0
-5
No files found.
torch_sparse/__init__.py
View file @
3c6dbfa1
...
...
@@ -45,3 +45,4 @@ from .masked_select import masked_select, masked_select_nnz
from
.diag
import
set_diag
,
remove_diag
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.reduce
import
sum
,
mean
,
min
,
max
torch_sparse/add.py
View file @
3c6dbfa1
...
...
@@ -9,8 +9,7 @@ from torch_sparse.tensor import SparseTensor
def
add
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
pass
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
...
...
@@ -30,8 +29,7 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
def
add_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
pass
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
...
...
torch_sparse/index_select.py
View file @
3c6dbfa1
...
...
@@ -25,8 +25,7 @@ def index_select(src: SparseTensor, dim: int,
device
=
col
.
device
).
repeat_interleave
(
rowcount
)
perm
=
torch
.
arange
(
row
.
size
(
0
),
device
=
row
.
device
)
# TODO
# perm += gather_csr(old_rowptr[idx] - rowptr[:-1], rowptr)
perm
+=
gather_csr
(
old_rowptr
[
idx
]
-
rowptr
[:
-
1
],
rowptr
)
col
=
col
[
perm
]
...
...
@@ -54,8 +53,7 @@ def index_select(src: SparseTensor, dim: int,
device
=
row
.
device
).
repeat_interleave
(
colcount
)
perm
=
torch
.
arange
(
col
.
size
(
0
),
device
=
col
.
device
)
# TODO
# perm += gather_csr(old_colptr[idx] - colptr[:-1], colptr)
perm
+=
gather_csr
(
old_colptr
[
idx
]
-
colptr
[:
-
1
],
colptr
)
row
=
row
[
perm
]
csc2csr
=
(
idx
.
size
(
0
)
*
row
+
col
).
argsort
()
...
...
torch_sparse/mul.py
View file @
3c6dbfa1
...
...
@@ -9,8 +9,7 @@ from torch_sparse.tensor import SparseTensor
def
mul
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
pass
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
...
...
@@ -30,8 +29,7 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
def
mul_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
# TODO
# other = gather_csr(other.squeeze(1), rowptr)
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
pass
elif
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
...
...
torch_sparse/reduce.py
View file @
3c6dbfa1
from
typing
import
Optional
import
torch
import
torch_scatter
from
torch_scatter
import
segment_csr
def
reduction
(
src
,
dim
=
None
,
reduce
=
'sum'
,
deterministic
=
False
):
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
if
dim
is
None
and
src
.
has_value
():
return
getattr
(
torch
,
reduce
)(
src
.
storage
.
value
)
if
dim
is
None
and
not
src
.
has_value
():
value
=
src
.
nnz
()
if
reduce
in
[
'sum'
,
'add'
]
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
dim
dims
=
sorted
([
src
.
dim
()
+
dim
if
dim
<
0
else
dim
for
dim
in
dims
])
assert
dims
[
-
1
]
<
src
.
dim
()
rowptr
,
col
,
value
=
src
.
csr
()
sparse_dims
=
tuple
(
set
([
d
for
d
in
dims
if
d
<
2
]))
dense_dims
=
tuple
(
set
([
d
-
1
for
d
in
dims
if
d
>
1
]))
if
len
(
sparse_dims
)
==
2
and
src
.
has_value
():
return
getattr
(
torch
,
reduce
)(
value
,
dim
=
(
0
,
)
+
dense_dims
)
if
len
(
sparse_dims
)
==
2
and
not
src
.
has_value
():
value
=
src
.
nnz
()
if
reduce
in
[
'sum'
,
'add'
]
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
from
torch_scatter
import
scatter
,
segment_csr
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
reduction
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
,
reduce
:
str
=
'sum'
)
->
torch
.
Tensor
:
value
=
src
.
storage
.
value
()
if
dim
is
None
:
if
value
is
not
None
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
value
.
sum
()
elif
reduce
==
'mean'
:
return
value
.
mean
()
elif
reduce
==
'min'
:
return
value
.
min
()
elif
reduce
==
'max'
:
return
value
.
max
()
else
:
raise
ValueError
else
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
torch
.
tensor
(
src
.
nnz
(),
dtype
=
src
.
dtype
(),
device
=
src
.
device
())
elif
reduce
==
'mean'
or
reduce
==
'min'
or
reduce
==
'max'
:
return
torch
.
tensor
(
1
,
dtype
=
src
.
dtype
(),
device
=
src
.
device
())
else
:
raise
ValueError
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
==
0
:
# src.has_value()
dense_dims
=
dense_dims
[
0
]
if
len
(
dense_dims
)
==
1
else
dense_dims
value
=
getattr
(
torch
,
reduce
)(
value
,
dim
=
dense_dims
)
if
isinstance
(
value
,
tuple
):
return
(
src
.
set_value
(
value
[
0
],
layout
=
'csr'
),
)
+
value
[
1
:]
return
src
.
set_value
(
value
,
layout
=
'csr'
)
else
:
if
dim
<
0
:
dim
=
src
.
dim
()
+
dim
if
dim
==
0
and
value
is
not
None
:
col
=
src
.
storage
.
col
()
return
scatter
(
value
,
col
,
dim
=
0
,
dim_size
=
src
.
size
(
0
))
elif
dim
==
0
and
value
is
None
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
src
.
storage
.
colcount
().
to
(
src
.
dtype
())
elif
reduce
==
'mean'
or
reduce
==
'min'
or
reduce
==
'max'
:
return
torch
.
ones
(
src
.
size
(
1
),
dtype
=
src
.
dtype
())
else
:
raise
ValueError
elif
dim
==
1
and
value
is
not
None
:
return
segment_csr
(
value
,
src
.
storage
.
rowptr
(),
None
,
reduce
)
elif
dim
==
1
and
value
is
None
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
src
.
storage
.
rowcount
().
to
(
src
.
dtype
())
elif
reduce
==
'mean'
or
reduce
==
'min'
or
reduce
==
'max'
:
return
torch
.
ones
(
src
.
size
(
0
),
dtype
=
src
.
dtype
())
else
:
raise
ValueError
elif
dim
>
1
and
value
is
not
None
:
if
reduce
==
'sum'
or
reduce
==
'add'
:
return
value
.
sum
(
dim
=
dim
-
1
)
elif
reduce
==
'mean'
:
return
value
.
mean
(
dim
=
dim
-
1
)
elif
reduce
==
'min'
:
return
value
.
min
(
dim
=
dim
-
1
)[
0
]
elif
reduce
==
'max'
:
return
value
.
max
(
dim
=
dim
-
1
)[
0
]
else
:
raise
ValueError
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
>
0
:
dense_dims
=
dense_dims
[
0
]
if
len
(
dense_dims
)
==
1
else
dense_dims
value
=
getattr
(
torch
,
reduce
)(
value
,
dim
=
dense_dims
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
else
:
raise
ValueError
if
sparse_dims
[
0
]
==
1
and
src
.
has_value
():
out
=
segment_csr
(
value
,
rowptr
)
out
=
out
[
0
]
if
len
(
dense_dims
)
>
0
and
isinstance
(
out
,
tuple
)
else
out
return
out
if
sparse_dims
[
0
]
==
1
and
not
src
.
has_value
():
if
reduce
in
[
'sum'
,
'add'
]:
return
src
.
storage
.
rowcount
.
to
(
torch
.
get_default_dtype
())
elif
reduce
==
'min'
or
'max'
:
# Return an additional `None` arg(min|max) tensor for consistency.
return
torch
.
ones
(
src
.
size
(
0
),
device
=
src
.
device
),
None
else
:
return
torch
.
ones
(
src
.
size
(
0
),
device
=
src
.
device
)
deterministic
=
src
.
storage
.
has_csr2csc
()
or
deterministic
if
sparse_dims
[
0
]
==
0
and
deterministic
and
src
.
has_value
():
csr2csc
=
src
.
storage
.
csr2csc
out
=
segment_csr
(
value
[
csr2csc
],
src
.
storage
.
colptr
)
out
=
out
[
0
]
if
len
(
dense_dims
)
>
0
and
isinstance
(
out
,
tuple
)
else
out
return
out
if
sparse_dims
[
0
]
==
0
and
src
.
has_value
():
reduce
=
'add'
if
reduce
==
'sum'
else
reduce
func
=
getattr
(
torch_scatter
,
f
'scatter_
{
reduce
}
'
)
out
=
func
(
value
,
col
,
dim
=
0
,
dim_size
=
src
.
sparse_size
(
1
))
out
=
out
[
0
]
if
len
(
dense_dims
)
>
0
and
isinstance
(
out
,
tuple
)
else
out
return
out
if
sparse_dims
[
0
]
==
0
and
not
src
.
has_value
():
if
reduce
in
[
'sum'
,
'add'
]:
return
src
.
storage
.
colcount
.
to
(
torch
.
get_default_dtype
())
elif
reduce
==
'min'
or
'max'
:
# Return an additional `None` arg(min|max) tensor for consistency.
return
torch
.
ones
(
src
.
size
(
1
),
device
=
src
.
device
),
None
else
:
return
torch
.
ones
(
src
.
size
(
1
),
device
=
src
.
device
)
@
torch
.
jit
.
script
def
sum
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'sum'
)
def
sum
(
src
,
dim
=
None
,
deterministic
=
False
):
return
reduction
(
src
,
dim
,
reduce
=
'sum'
,
deterministic
=
deterministic
)
@
torch
.
jit
.
script
def
mean
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'mean'
)
def
mean
(
src
,
dim
=
None
,
deterministic
=
False
):
return
reduction
(
src
,
dim
,
reduce
=
'mean'
,
deterministic
=
deterministic
)
@
torch
.
jit
.
script
def
min
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'min'
)
def
min
(
src
,
dim
=
None
,
deterministic
=
False
):
return
reduction
(
src
,
dim
,
reduce
=
'min'
,
deterministic
=
deterministic
)
@
torch
.
jit
.
script
def
max
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'max'
)
def
max
(
src
,
dim
=
None
,
deterministic
=
False
):
return
reduction
(
src
,
dim
,
reduce
=
'max'
,
deterministic
=
deterministic
)
SparseTensor
.
sum
=
lambda
self
,
dim
=
None
:
sum
(
self
,
dim
)
SparseTensor
.
mean
=
lambda
self
,
dim
=
None
:
mean
(
self
,
dim
)
SparseTensor
.
min
=
lambda
self
,
dim
=
None
:
min
(
self
,
dim
)
SparseTensor
.
max
=
lambda
self
,
dim
=
None
:
max
(
self
,
dim
)
torch_sparse/storage.py
View file @
3c6dbfa1
...
...
@@ -304,9 +304,8 @@ class SparseStorage(object):
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[
1
:]
else
:
raise
NotImplementedError
# colcount = scatter_add(torch.ones_like(self._col), self._col,
# dim_size=self._sparse_sizes[1])
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
self
.
_colcount
=
colcount
return
colcount
...
...
@@ -355,8 +354,7 @@ class SparseStorage(object):
if
value
is
not
None
:
ptr
=
mask
.
nonzero
().
flatten
()
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
raise
NotImplementedError
# value = segment_csr(value, ptr, reduce=reduce)
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
...
...
torch_sparse/tensor.py
View file @
3c6dbfa1
...
...
@@ -380,11 +380,6 @@ class SparseTensor(object):
# return matmul(self, other, reduce='sum')
# SparseTensor.reduction = torch_sparse.reduce.reduction
# SparseTensor.sum = torch_sparse.reduce.sum
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.matmul = matmul
# Python Bindings #############################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment