Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
b50a4861
Commit
b50a4861
authored
Jan 25, 2020
by
rusty1s
Browse files
typos
parent
fcf88282
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
31 deletions
+16
-31
torch_sparse/reduce.py
torch_sparse/reduce.py
+5
-5
torch_sparse/tensor.py
torch_sparse/tensor.py
+11
-26
No files found.
torch_sparse/reduce.py
View file @
b50a4861
...
@@ -4,13 +4,13 @@ from torch_scatter import segment_csr
...
@@ -4,13 +4,13 @@ from torch_scatter import segment_csr
def
reduction
(
src
,
dim
=
None
,
reduce
=
'sum'
,
deterministic
=
False
):
def
reduction
(
src
,
dim
=
None
,
reduce
=
'sum'
,
deterministic
=
False
):
assert
reduce
in
[
'sum'
,
'mean'
,
'min'
,
'max'
]
assert
reduce
in
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
]
if
dim
is
None
and
src
.
has_value
():
if
dim
is
None
and
src
.
has_value
():
return
getattr
(
torch
,
reduce
)(
src
.
storage
.
value
)
return
getattr
(
torch
,
reduce
)(
src
.
storage
.
value
)
if
dim
is
None
and
not
src
.
has_value
():
if
dim
is
None
and
not
src
.
has_value
():
value
=
src
.
nnz
()
if
reduce
==
'sum'
else
1
value
=
src
.
nnz
()
if
reduce
in
[
'sum'
,
'add'
]
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
dim
dims
=
[
dim
]
if
isinstance
(
dim
,
int
)
else
dim
...
@@ -26,7 +26,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
...
@@ -26,7 +26,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
return
getattr
(
torch
,
reduce
)(
value
,
dim
=
(
0
,
)
+
dense_dims
)
return
getattr
(
torch
,
reduce
)(
value
,
dim
=
(
0
,
)
+
dense_dims
)
if
len
(
sparse_dims
)
==
2
and
not
src
.
has_value
():
if
len
(
sparse_dims
)
==
2
and
not
src
.
has_value
():
value
=
src
.
nnz
()
if
reduce
==
'sum'
else
1
value
=
src
.
nnz
()
if
reduce
in
[
'sum'
,
'add'
]
else
1
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
return
torch
.
tensor
(
value
,
device
=
src
.
device
)
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
==
0
:
# src.has_value()
if
len
(
dense_dims
)
>
0
and
len
(
sparse_dims
)
==
0
:
# src.has_value()
...
@@ -47,7 +47,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
...
@@ -47,7 +47,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
return
out
return
out
if
sparse_dims
[
0
]
==
1
and
not
src
.
has_value
():
if
sparse_dims
[
0
]
==
1
and
not
src
.
has_value
():
if
reduce
==
'sum'
:
if
reduce
in
[
'sum'
,
'add'
]
:
return
src
.
storage
.
rowcount
.
to
(
torch
.
get_default_dtype
())
return
src
.
storage
.
rowcount
.
to
(
torch
.
get_default_dtype
())
elif
reduce
==
'min'
or
'max'
:
elif
reduce
==
'min'
or
'max'
:
# Return an additional `None` arg(min|max) tensor for consistency.
# Return an additional `None` arg(min|max) tensor for consistency.
...
@@ -71,7 +71,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
...
@@ -71,7 +71,7 @@ def reduction(src, dim=None, reduce='sum', deterministic=False):
return
out
return
out
if
sparse_dims
[
0
]
==
0
and
not
src
.
has_value
():
if
sparse_dims
[
0
]
==
0
and
not
src
.
has_value
():
if
reduce
==
'sum'
:
if
reduce
in
[
'sum'
,
'add'
]
:
return
src
.
storage
.
colcount
.
to
(
torch
.
get_default_dtype
())
return
src
.
storage
.
colcount
.
to
(
torch
.
get_default_dtype
())
elif
reduce
==
'min'
or
'max'
:
elif
reduce
==
'min'
or
'max'
:
# Return an additional `None` arg(min|max) tensor for consistency.
# Return an additional `None` arg(min|max) tensor for consistency.
...
...
torch_sparse/tensor.py
View file @
b50a4861
...
@@ -433,8 +433,17 @@ class SparseTensor(object):
...
@@ -433,8 +433,17 @@ class SparseTensor(object):
def
__iadd__
(
self
,
other
):
def
__iadd__
(
self
,
other
):
return
self
.
add_
(
other
)
return
self
.
add_
(
other
)
def
__matmul__
(
a
,
b
):
def
__mul__
(
self
,
other
):
return
matmul
(
a
,
b
,
reduce
=
'sum'
)
return
self
.
mul
(
other
)
def
__rmul__
(
self
,
other
):
return
self
.
mul
(
other
)
def
__imul__
(
self
,
other
):
return
self
.
mul_
(
other
)
def
__matmul__
(
self
,
other
):
return
matmul
(
self
,
other
,
reduce
=
'sum'
)
# String Reputation #######################################################
# String Reputation #######################################################
...
@@ -479,27 +488,3 @@ SparseTensor.add = add
...
@@ -479,27 +488,3 @@ SparseTensor.add = add
SparseTensor
.
add_
=
add_
SparseTensor
.
add_
=
add_
SparseTensor
.
add_nnz
=
add_nnz
SparseTensor
.
add_nnz
=
add_nnz
SparseTensor
.
add_nnz_
=
add_nnz_
SparseTensor
.
add_nnz_
=
add_nnz_
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def sub(self, layout=None):
# raise NotImplementedError
# def sub_(self, layout=None):
# raise NotImplementedError
# def mul(self, layout=None):
# raise NotImplementedError
# def mul_(self, layout=None):
# raise NotImplementedError
# def div(self, layout=None):
# raise NotImplementedError
# def div_(self, layout=None):
# raise NotImplementedError
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment