Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8819288a
Commit
8819288a
authored
Jan 26, 2020
by
rusty1s
Browse files
overload fix
parent
73146b9b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
147 additions
and
8 deletions
+147
-8
test/test_overload.py
test/test_overload.py
+20
-0
torch_sparse/add.py
torch_sparse/add.py
+8
-4
torch_sparse/mul.py
torch_sparse/mul.py
+96
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+23
-4
No files found.
test/test_overload.py
0 → 100644
View file @
8819288a
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
test_overload
():
row
=
torch
.
tensor
([
0
,
1
,
1
,
2
,
2
])
col
=
torch
.
tensor
([
1
,
0
,
2
,
1
,
2
])
mat
=
SparseTensor
(
row
=
row
,
col
=
col
)
other
=
torch
.
tensor
([
1
,
2
,
3
]).
view
(
3
,
1
)
other
+
mat
mat
+
other
other
*
mat
mat
*
other
other
=
torch
.
tensor
([
1
,
2
,
3
]).
view
(
1
,
3
)
other
+
mat
mat
+
other
other
*
mat
mat
*
other
torch_sparse/add.py
View file @
8819288a
...
...
@@ -2,6 +2,10 @@ import torch
from
torch_scatter
import
gather_csr
def
is_scalar
(
other
):
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
def
sparse_add
(
matA
,
matB
):
nnzA
,
nnzB
=
matA
.
nnz
(),
matB
.
nnz
()
valA
=
torch
.
full
((
nnzA
,
),
1
,
dtype
=
torch
.
uint8
,
device
=
matA
.
device
)
...
...
@@ -38,7 +42,7 @@ def sparse_add(matA, matB):
def
add
(
src
,
other
):
if
is
instance
(
other
,
int
)
or
isinstance
(
other
,
float
):
if
is
_scalar
(
other
):
return
add_nnz
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
...
...
@@ -65,7 +69,7 @@ def add(src, other):
def
add_
(
src
,
other
):
if
is
instance
(
other
,
int
)
or
isinstance
(
other
,
float
):
if
is
_scalar
(
other
):
return
add_nnz_
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
...
...
@@ -98,7 +102,7 @@ def add_(src, other):
def
add_nnz
(
src
,
other
,
layout
=
None
):
if
is
instance
(
other
,
int
)
or
isinstance
(
other
,
float
):
if
is
_scalar
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
+
other
else
:
...
...
@@ -117,7 +121,7 @@ def add_nnz(src, other, layout=None):
def
add_nnz_
(
src
,
other
,
layout
=
None
):
if
is
instance
(
other
,
int
)
or
isinstance
(
other
,
float
):
if
is
_scalar
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
...
...
torch_sparse/mul.py
0 → 100644
View file @
8819288a
import
torch
from
torch_scatter
import
gather_csr
def
is_scalar
(
other
):
return
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
)
def
mul
(
src
,
other
):
if
is_scalar
(
other
):
return
mul_nnz
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
if
src
.
has_value
():
value
=
other
.
mul_
(
src
.
storage
.
value
)
else
:
value
=
other
return
src
.
set_value
(
value
,
layout
=
'csr'
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
if
src
.
has_value
():
value
=
other
.
mul_
(
src
.
storage
.
value
)
else
:
value
=
other
return
src
.
set_value
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
raise
NotImplementedError
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.'
)
def
mul_
(
src
,
other
):
if
is_scalar
(
other
):
return
mul_nnz_
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
rowptr
)
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
mul_
(
other
)
else
:
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'csr'
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
mul_
(
other
)
else
:
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
raise
NotImplementedError
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.'
)
def
mul_nnz
(
src
,
other
,
layout
=
None
):
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
*
other
else
:
value
=
other
return
src
.
set_value
(
value
,
layout
=
'coo'
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.'
)
def
mul_nnz_
(
src
,
other
,
layout
=
None
):
if
torch
.
is_tensor
(
other
)
or
is_scalar
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
mul_
(
other
)
else
:
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.'
)
torch_sparse/tensor.py
View file @
8819288a
...
...
@@ -14,6 +14,7 @@ import torch_sparse.reduce
from
torch_sparse.diag
import
remove_diag
,
set_diag
from
torch_sparse.matmul
import
matmul
from
torch_sparse.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
torch_sparse.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
class
SparseTensor
(
object
):
...
...
@@ -455,7 +456,7 @@ class SparseTensor(object):
infos
+=
[
f
'col=
{
indent
(
col
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
if
self
.
has_value
():
infos
+=
[
f
'val
ue
=
{
indent
(
value
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'val=
{
indent
(
value
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'size=
{
tuple
(
self
.
size
())
}
, '
...
...
@@ -482,10 +483,28 @@ SparseTensor.sum = torch_sparse.reduce.sum
SparseTensor
.
mean
=
torch_sparse
.
reduce
.
mean
SparseTensor
.
min
=
torch_sparse
.
reduce
.
min
SparseTensor
.
max
=
torch_sparse
.
reduce
.
max
SparseTensor
.
remove_diag
=
remove_diag
#TODO
SparseTensor
.
set_diag
=
set_diag
#TODO
SparseTensor
.
matmul
=
matmul
# TODO
SparseTensor
.
remove_diag
=
remove_diag
SparseTensor
.
set_diag
=
set_diag
SparseTensor
.
matmul
=
matmul
SparseTensor
.
add
=
add
SparseTensor
.
add_
=
add_
SparseTensor
.
add_nnz
=
add_nnz
SparseTensor
.
add_nnz_
=
add_nnz_
SparseTensor
.
mul
=
mul
SparseTensor
.
mul_
=
mul_
SparseTensor
.
mul_nnz
=
mul_nnz
SparseTensor
.
mul_nnz_
=
mul_nnz_
# Fix for PyTorch<=1.3 (https://github.com/pytorch/pytorch/pull/31769):
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
if
(
TORCH_MAJOR
<=
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
<
4
):
def
add
(
self
,
other
):
return
self
.
add
(
other
)
if
torch
.
is_tensor
(
other
)
else
NotImplemented
def
mul
(
self
,
other
):
return
self
.
mul
(
other
)
if
torch
.
is_tensor
(
other
)
else
NotImplemented
torch
.
Tensor
.
__add__
=
add
torch
.
Tensor
.
__mul__
=
add
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment