Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
34b25b3c
Commit
34b25b3c
authored
Jan 29, 2020
by
rusty1s
Browse files
fixes
parent
4a68dd60
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
30 deletions
+22
-30
test/test_jit.py
test/test_jit.py
+4
-1
torch_sparse/add.py
torch_sparse/add.py
+7
-4
torch_sparse/mul.py
torch_sparse/mul.py
+7
-4
torch_sparse/tensor.py
torch_sparse/tensor.py
+4
-21
No files found.
test/test_jit.py
View file @
34b25b3c
...
@@ -71,10 +71,13 @@ def test_jit():
...
@@ -71,10 +71,13 @@ def test_jit():
# scipy = adj.to_scipy(layout='csr')
# scipy = adj.to_scipy(layout='csr')
# mat = SparseTensor.from_scipy(scipy)
# mat = SparseTensor.from_scipy(scipy)
print
()
print
()
print
(
adj
)
# adj = t(adj)
# adj = t(adj)
adj
=
adj
.
t
()
adj
=
adj
.
t
()
adj
=
adj
.
remove_diag
(
k
=
0
)
print
(
adj
.
to_dense
())
adj
=
adj
+
torch
.
tensor
([
1
,
2
,
3
]).
view
(
1
,
3
)
print
(
adj
)
print
(
adj
)
print
(
adj
.
to_dense
())
# print(adj.t)
# print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
...
...
torch_sparse/add.py
View file @
34b25b3c
...
@@ -20,7 +20,7 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -20,7 +20,7 @@ def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
if
value
is
not
None
:
if
value
is
not
None
:
value
=
other
.
add_
(
value
)
value
=
other
.
to
(
value
.
dtype
).
add_
(
value
)
else
:
else
:
value
=
other
.
add_
(
1
)
value
=
other
.
add_
(
1
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
...
@@ -41,7 +41,7 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -41,7 +41,7 @@ def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
add_
(
other
)
value
=
value
.
add_
(
other
.
to
(
value
.
dtype
)
)
else
:
else
:
value
=
other
.
add_
(
1
)
value
=
other
.
add_
(
1
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
...
@@ -52,7 +52,7 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
...
@@ -52,7 +52,7 @@ def add_nnz(src: SparseTensor, other: torch.Tensor,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
add
(
other
)
value
=
value
.
add
(
other
.
to
(
value
.
dtype
)
)
else
:
else
:
value
=
other
.
add
(
1
)
value
=
other
.
add
(
1
)
return
src
.
set_value
(
value
,
layout
=
layout
)
return
src
.
set_value
(
value
,
layout
=
layout
)
...
@@ -63,7 +63,7 @@ def add_nnz_(src: SparseTensor, other: torch.Tensor,
...
@@ -63,7 +63,7 @@ def add_nnz_(src: SparseTensor, other: torch.Tensor,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
add_
(
other
)
value
=
value
.
add_
(
other
.
to
(
value
.
dtype
)
)
else
:
else
:
value
=
other
.
add
(
1
)
value
=
other
.
add
(
1
)
return
src
.
set_value_
(
value
,
layout
=
layout
)
return
src
.
set_value_
(
value
,
layout
=
layout
)
...
@@ -75,3 +75,6 @@ SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
...
@@ -75,3 +75,6 @@ SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
self
,
other
,
layout
)
self
,
other
,
layout
)
SparseTensor
.
add_nnz_
=
lambda
self
,
other
,
layout
=
None
:
add_nnz_
(
SparseTensor
.
add_nnz_
=
lambda
self
,
other
,
layout
=
None
:
add_nnz_
(
self
,
other
,
layout
)
self
,
other
,
layout
)
SparseTensor
.
__add__
=
SparseTensor
.
add
SparseTensor
.
__radd__
=
SparseTensor
.
add
SparseTensor
.
__iadd__
=
SparseTensor
.
add_
torch_sparse/mul.py
View file @
34b25b3c
...
@@ -20,7 +20,7 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -20,7 +20,7 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
if
value
is
not
None
:
if
value
is
not
None
:
value
=
other
.
mul_
(
value
)
value
=
other
.
to
(
value
.
dtype
).
mul_
(
value
)
else
:
else
:
value
=
other
value
=
other
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
...
@@ -41,7 +41,7 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
...
@@ -41,7 +41,7 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
f
'
{
other
.
size
()
}
.'
)
f
'
{
other
.
size
()
}
.'
)
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
mul_
(
other
)
value
=
value
.
mul_
(
other
.
to
(
value
.
dtype
)
)
else
:
else
:
value
=
other
value
=
other
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
...
@@ -52,7 +52,7 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
...
@@ -52,7 +52,7 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
mul
(
other
)
value
=
value
.
mul
(
other
.
to
(
value
.
dtype
)
)
else
:
else
:
value
=
other
value
=
other
return
src
.
set_value
(
value
,
layout
=
layout
)
return
src
.
set_value
(
value
,
layout
=
layout
)
...
@@ -63,7 +63,7 @@ def mul_nnz_(src: SparseTensor, other: torch.Tensor,
...
@@ -63,7 +63,7 @@ def mul_nnz_(src: SparseTensor, other: torch.Tensor,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
value
=
src
.
storage
.
value
()
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
mul_
(
other
)
value
=
value
.
mul_
(
other
.
to
(
value
.
dtype
)
)
else
:
else
:
value
=
other
value
=
other
return
src
.
set_value_
(
value
,
layout
=
layout
)
return
src
.
set_value_
(
value
,
layout
=
layout
)
...
@@ -75,3 +75,6 @@ SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
...
@@ -75,3 +75,6 @@ SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
self
,
other
,
layout
)
self
,
other
,
layout
)
SparseTensor
.
mul_nnz_
=
lambda
self
,
other
,
layout
=
None
:
mul_nnz_
(
SparseTensor
.
mul_nnz_
=
lambda
self
,
other
,
layout
=
None
:
mul_nnz_
(
self
,
other
,
layout
)
self
,
other
,
layout
)
SparseTensor
.
__mul__
=
SparseTensor
.
mul
SparseTensor
.
__rmul__
=
SparseTensor
.
mul
SparseTensor
.
__imul__
=
SparseTensor
.
mul_
torch_sparse/tensor.py
View file @
34b25b3c
...
@@ -345,7 +345,10 @@ class SparseTensor(object):
...
@@ -345,7 +345,10 @@ class SparseTensor(object):
def
to_dense
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
def
to_dense
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
row
,
col
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
if
options
is
not
None
:
if
value
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
value
.
dtype
,
device
=
self
.
device
())
elif
options
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
options
.
dtype
,
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
device
=
self
.
device
())
else
:
else
:
...
@@ -373,24 +376,6 @@ class SparseTensor(object):
...
@@ -373,24 +376,6 @@ class SparseTensor(object):
# Standard Operators ######################################################
# Standard Operators ######################################################
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def __iadd__(self, other):
# return self.add_(other)
# def __mul__(self, other):
# return self.mul(other)
# def __rmul__(self, other):
# return self.mul(other)
# def __imul__(self, other):
# return self.mul_(other)
# def __matmul__(self, other):
# def __matmul__(self, other):
# return matmul(self, other, reduce='sum')
# return matmul(self, other, reduce='sum')
...
@@ -400,8 +385,6 @@ class SparseTensor(object):
...
@@ -400,8 +385,6 @@ class SparseTensor(object):
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.mean = torch_sparse.reduce.mean
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.min = torch_sparse.reduce.min
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.max = torch_sparse.reduce.max
# SparseTensor.remove_diag = remove_diag
# SparseTensor.set_diag = set_diag
# SparseTensor.matmul = matmul
# SparseTensor.matmul = matmul
# Python Bindings #############################################################
# Python Bindings #############################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment