Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
592d63d2
Commit
592d63d2
authored
Jan 28, 2020
by
rusty1s
Browse files
repr
parent
f87afd09
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
97 additions
and
91 deletions
+97
-91
test/test_jit.py
test/test_jit.py
+6
-5
torch_sparse/tensor.py
torch_sparse/tensor.py
+91
-86
No files found.
test/test_jit.py
View file @
592d63d2
...
@@ -65,13 +65,14 @@ def test_jit():
...
@@ -65,13 +65,14 @@ def test_jit():
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = adj.storage
# adj = adj.storage
rowptr
=
torch
.
tensor
([
0
,
3
,
6
,
9
])
rowptr
=
torch
.
tensor
([
0
,
1
,
4
,
7
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
0
,
1
,
2
,
0
,
1
,
2
])
col
=
torch
.
tensor
([
0
,
0
,
1
,
2
,
0
,
1
,
2
])
adj
=
SparseTensor
(
rowptr
=
rowptr
,
col
=
col
)
adj
=
SparseTensor
(
rowptr
=
rowptr
,
col
=
col
)
scipy
=
adj
.
to_scipy
(
layout
=
'csr'
)
# scipy = adj.to_scipy(layout='csr')
mat
=
SparseTensor
.
from_scipy
(
scipy
)
# mat = SparseTensor.from_scipy(scipy)
mat
.
fill_value_
(
2.3
)
print
()
print
(
adj
)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# foo = Foo(mat.storage.rowptr, mat.storage.col)
...
...
torch_sparse/tensor.py
View file @
592d63d2
#
from textwrap import indent
from
textwrap
import
indent
from
typing
import
Optional
,
List
,
Tuple
,
Union
from
typing
import
Optional
,
List
,
Tuple
,
Union
import
torch
import
torch
...
@@ -382,92 +382,29 @@ class SparseTensor(object):
...
@@ -382,92 +382,29 @@ class SparseTensor(object):
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
sizes
())
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
sizes
())
# Standard Operators ######################################################
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def __iadd__(self, other):
# return self.add_(other)
# def __mul__(self, other):
# return self.mul(other)
# def __rmul__(self, other):
# return self.mul(other)
# def __imul__(self, other):
# return self.mul_(other)
# def __matmul__(self, other):
# return matmul(self, other, reduce='sum')
# # Standard Operators ######################################################
# def __getitem__(self, index):
# index = list(index) if isinstance(index, tuple) else [index]
# # More than one `Ellipsis` is not allowed...
# if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
# raise SyntaxError
# dim = 0
# out = self
# while len(index) > 0:
# item = index.pop(0)
# if isinstance(item, int):
# out = out.select(dim, item)
# dim += 1
# elif isinstance(item, slice):
# if item.step is not None:
# raise ValueError('Step parameter not yet supported.')
# start = 0 if item.start is None else item.start
# start = self.size(dim) + start if start < 0 else start
# stop = self.size(dim) if item.stop is None else item.stop
# stop = self.size(dim) + stop if stop < 0 else stop
# out = out.narrow(dim, start, max(stop - start, 0))
# dim += 1
# elif torch.is_tensor(item):
# if item.dtype == torch.bool:
# out = out.masked_select(dim, item)
# dim += 1
# elif item.dtype == torch.long:
# out = out.index_select(dim, item)
# dim += 1
# elif item == Ellipsis:
# if self.dim() - len(index) < dim:
# raise SyntaxError
# dim = self.dim() - len(index)
# else:
# raise SyntaxError
# return out
# def __add__(self, other):
# return self.add(other)
# def __radd__(self, other):
# return self.add(other)
# def __iadd__(self, other):
# return self.add_(other)
# def __mul__(self, other):
# return self.mul(other)
# def __rmul__(self, other):
# return self.mul(other)
# def __imul__(self, other):
# return self.mul_(other)
# def __matmul__(self, other):
# return matmul(self, other, reduce='sum')
# # String Reputation #######################################################
# def __repr__(self):
# i = ' ' * 6
# row, col, value = self.coo()
# infos = []
# infos += [f'row={indent(row.__repr__(), i)[len(i):]}']
# infos += [f'col={indent(col.__repr__(), i)[len(i):]}']
# if self.has_value():
# infos += [f'val={indent(value.__repr__(), i)[len(i):]}']
# infos += [
# f'size={tuple(self.size())}, '
# f'nnz={self.nnz()}, '
# f'density={100 * self.density():.02f}%'
# ]
# infos = ',\n'.join(infos)
# i = ' ' * (len(self.__class__.__name__) + 1)
# return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'
# Bindings ####################################################################
# Bindings ####################################################################
...
@@ -531,9 +468,77 @@ def to(self, *args, **kwargs):
...
@@ -531,9 +468,77 @@ def to(self, *args, **kwargs):
return
self
return
self
@
torch
.
jit
.
ignore
def
__getitem__
(
self
,
index
):
raise
NotImplementedError
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
# More than one `Ellipsis` is not allowed...
if
len
([
i
for
i
in
index
if
not
torch
.
is_tensor
(
i
)
and
i
==
...])
>
1
:
raise
SyntaxError
dim
=
0
out
=
self
while
len
(
index
)
>
0
:
item
=
index
.
pop
(
0
)
if
isinstance
(
item
,
int
):
out
=
out
.
select
(
dim
,
item
)
dim
+=
1
elif
isinstance
(
item
,
slice
):
if
item
.
step
is
not
None
:
raise
ValueError
(
'Step parameter not yet supported.'
)
start
=
0
if
item
.
start
is
None
else
item
.
start
start
=
self
.
size
(
dim
)
+
start
if
start
<
0
else
start
stop
=
self
.
size
(
dim
)
if
item
.
stop
is
None
else
item
.
stop
stop
=
self
.
size
(
dim
)
+
stop
if
stop
<
0
else
stop
out
=
out
.
narrow
(
dim
,
start
,
max
(
stop
-
start
,
0
))
dim
+=
1
elif
torch
.
is_tensor
(
item
):
if
item
.
dtype
==
torch
.
bool
:
out
=
out
.
masked_select
(
dim
,
item
)
dim
+=
1
elif
item
.
dtype
==
torch
.
long
:
out
=
out
.
index_select
(
dim
,
item
)
dim
+=
1
elif
item
==
Ellipsis
:
if
self
.
dim
()
-
len
(
index
)
<
dim
:
raise
SyntaxError
dim
=
self
.
dim
()
-
len
(
index
)
else
:
raise
SyntaxError
return
out
@
torch
.
jit
.
ignore
def
__repr__
(
self
):
i
=
' '
*
6
row
,
col
,
value
=
self
.
coo
()
infos
=
[]
infos
+=
[
f
'row=
{
indent
(
row
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'col=
{
indent
(
col
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
if
value
is
not
None
:
infos
+=
[
f
'val=
{
indent
(
value
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'size=
{
tuple
(
self
.
sizes
())
}
, '
f
'nnz=
{
self
.
nnz
()
}
, '
f
'density=
{
100
*
self
.
density
():.
02
f
}
%'
]
infos
=
',
\n
'
.
join
(
infos
)
i
=
' '
*
(
len
(
self
.
__class__
.
__name__
)
+
1
)
return
f
'
{
self
.
__class__
.
__name__
}
(
{
indent
(
infos
,
i
)[
len
(
i
):]
}
)'
SparseTensor
.
share_memory_
=
share_memory_
SparseTensor
.
share_memory_
=
share_memory_
SparseTensor
.
is_shared
=
is_shared
SparseTensor
.
is_shared
=
is_shared
SparseTensor
.
to
=
to
SparseTensor
.
to
=
to
SparseTensor
.
__getitem__
=
__getitem__
SparseTensor
.
__repr__
=
__repr__
# Scipy Conversions ###########################################################
# Scipy Conversions ###########################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment