Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
fcf88282
Commit
fcf88282
authored
Jan 25, 2020
by
rusty1s
Browse files
tensor fixes
parent
918b1163
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
101 additions
and
83 deletions
+101
-83
cuda/convert_kernel.cu
cuda/convert_kernel.cu
+1
-0
torch_sparse/storage.py
torch_sparse/storage.py
+6
-10
torch_sparse/tensor.py
torch_sparse/tensor.py
+94
-73
No files found.
cuda/convert_kernel.cu
View file @
fcf88282
...
@@ -37,6 +37,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
...
@@ -37,6 +37,7 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
int64_t
thread_idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int64_t
thread_idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// TODO: Make more efficient.
if
(
thread_idx
<
numel
)
{
if
(
thread_idx
<
numel
)
{
int64_t
idx
=
ptr_data
[
thread_idx
],
next_idx
=
ptr_data
[
thread_idx
+
1
];
int64_t
idx
=
ptr_data
[
thread_idx
],
next_idx
=
ptr_data
[
thread_idx
+
1
];
for
(
int64_t
i
=
idx
;
i
<
next_idx
;
i
++
)
{
for
(
int64_t
i
=
idx
;
i
<
next_idx
;
i
++
)
{
...
...
torch_sparse/storage.py
View file @
fcf88282
...
@@ -99,7 +99,6 @@ class SparseStorage(object):
...
@@ -99,7 +99,6 @@ class SparseStorage(object):
if
value
is
not
None
:
if
value
is
not
None
:
assert
value
.
device
==
col
.
device
assert
value
.
device
==
col
.
device
assert
value
.
size
(
0
)
==
col
.
size
(
0
)
assert
value
.
size
(
0
)
==
col
.
size
(
0
)
value
=
value
.
contiguous
()
if
rowcount
is
not
None
:
if
rowcount
is
not
None
:
assert
rowcount
.
dtype
==
torch
.
long
assert
rowcount
.
dtype
==
torch
.
long
...
@@ -160,7 +159,7 @@ class SparseStorage(object):
...
@@ -160,7 +159,7 @@ class SparseStorage(object):
def
row
(
self
):
def
row
(
self
):
if
self
.
_row
is
None
:
if
self
.
_row
is
None
:
func
=
convert_cuda
if
self
.
rowptr
.
is_cuda
else
convert_cpu
func
=
convert_cuda
if
self
.
rowptr
.
is_cuda
else
convert_cpu
self
.
_row
=
func
.
ptr2ind
(
self
.
rowptr
,
self
.
nnz
())
self
.
_row
=
func
.
ptr2ind
(
self
.
rowptr
,
self
.
col
.
numel
())
return
self
.
_row
return
self
.
_row
def
has_rowptr
(
self
):
def
has_rowptr
(
self
):
...
@@ -184,9 +183,9 @@ class SparseStorage(object):
...
@@ -184,9 +183,9 @@ class SparseStorage(object):
def
value
(
self
):
def
value
(
self
):
return
self
.
_value
return
self
.
_value
def
set_value_
(
self
,
value
,
dtype
=
None
,
layout
=
None
):
def
set_value_
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
value
=
torch
.
full
((
self
.
nnz
(),
),
dtype
=
dtype
,
value
=
torch
.
full
((
self
.
col
.
numel
(),
),
dtype
=
dtype
,
device
=
self
.
col
.
device
)
device
=
self
.
col
.
device
)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
...
@@ -200,9 +199,9 @@ class SparseStorage(object):
...
@@ -200,9 +199,9 @@ class SparseStorage(object):
self
.
_value
=
value
self
.
_value
=
value
return
self
return
self
def
set_value
(
self
,
value
,
dtype
=
None
,
layout
=
None
):
def
set_value
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
if
isinstance
(
value
,
int
)
or
isinstance
(
value
,
float
):
value
=
torch
.
full
((
self
.
nnz
(),
),
dtype
=
dtype
,
value
=
torch
.
full
((
self
.
col
.
numel
(),
),
dtype
=
dtype
,
device
=
self
.
col
.
device
)
device
=
self
.
col
.
device
)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
...
@@ -224,7 +223,7 @@ class SparseStorage(object):
...
@@ -224,7 +223,7 @@ class SparseStorage(object):
return
self
.
_sparse_size
return
self
.
_sparse_size
def
sparse_resize
(
self
,
*
sizes
):
def
sparse_resize
(
self
,
*
sizes
):
old_sparse_size
,
nnz
=
self
.
sparse_size
,
self
.
nnz
()
old_sparse_size
,
nnz
=
self
.
sparse_size
,
self
.
col
.
numel
()
diff_0
=
sizes
[
0
]
-
old_sparse_size
[
0
]
diff_0
=
sizes
[
0
]
-
old_sparse_size
[
0
]
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
rowcount
,
rowptr
=
self
.
_rowcount
,
self
.
_rowptr
...
@@ -258,9 +257,6 @@ class SparseStorage(object):
...
@@ -258,9 +257,6 @@ class SparseStorage(object):
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
nnz
(
self
):
return
self
.
col
.
numel
()
def
has_rowcount
(
self
):
def
has_rowcount
(
self
):
return
self
.
_rowcount
is
not
None
return
self
.
_rowcount
is
not
None
...
...
torch_sparse/tensor.py
View file @
fcf88282
...
@@ -13,11 +13,14 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
...
@@ -13,11 +13,14 @@ from torch_sparse.masked_select import masked_select, masked_select_nnz
import
torch_sparse.reduce
import
torch_sparse.reduce
from
torch_sparse.diag
import
remove_diag
from
torch_sparse.diag
import
remove_diag
from
torch_sparse.matmul
import
matmul
from
torch_sparse.matmul
import
matmul
from
torch_sparse.add
import
add
,
add_
,
add_nnz
,
add_nnz_
class
SparseTensor
(
object
):
class
SparseTensor
(
object
):
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
is_sorted
=
False
):
def
__init__
(
self
,
row
=
None
,
rowptr
=
None
,
col
=
None
,
value
=
None
,
self
.
storage
=
SparseStorage
(
index
,
value
,
sparse_size
,
sparse_size
=
None
,
is_sorted
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_size
=
sparse_size
,
is_sorted
=
is_sorted
)
is_sorted
=
is_sorted
)
@
classmethod
@
classmethod
...
@@ -33,14 +36,15 @@ class SparseTensor(object):
...
@@ -33,14 +36,15 @@ class SparseTensor(object):
else
:
else
:
index
=
mat
.
nonzero
()
index
=
mat
.
nonzero
()
index
=
index
.
t
().
contiguous
()
row
,
col
=
index
.
t
().
contiguous
()
value
=
mat
[
index
[
0
],
index
[
1
]]
return
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
mat
[
row
,
col
],
return
SparseTensor
(
index
,
value
,
mat
.
size
()[:
2
],
is_sorted
=
True
)
sparse_size
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
@
classmethod
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
,
is_sorted
=
False
):
def
from_torch_sparse_coo_tensor
(
self
,
mat
,
is_sorted
=
False
):
return
SparseTensor
(
mat
.
_indices
(),
mat
.
_values
(),
row
,
col
=
mat
.
_indices
()
mat
.
size
()[:
2
],
is_sorted
=
is_sorted
)
return
SparseTensor
(
row
=
row
,
col
=
col
,
value
=
mat
.
_values
(),
sparse_size
=
mat
.
size
()[:
2
],
is_sorted
=
is_sorted
)
@
classmethod
@
classmethod
def
from_scipy
(
self
,
mat
):
def
from_scipy
(
self
,
mat
):
...
@@ -48,60 +52,52 @@ class SparseTensor(object):
...
@@ -48,60 +52,52 @@ class SparseTensor(object):
if
isinstance
(
mat
,
scipy
.
sparse
.
csc_matrix
):
if
isinstance
(
mat
,
scipy
.
sparse
.
csc_matrix
):
colptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
colptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
mat
=
mat
.
tocsr
()
mat
=
mat
.
tocsr
()
# Pre-sort.
rowptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
rowptr
=
torch
.
from_numpy
(
mat
.
indptr
).
to
(
torch
.
long
)
mat
=
mat
.
tocoo
()
mat
=
mat
.
tocoo
()
row
=
torch
.
from_numpy
(
mat
.
row
).
to
(
torch
.
long
)
row
=
torch
.
from_numpy
(
mat
.
row
).
to
(
torch
.
long
)
col
=
torch
.
from_numpy
(
mat
.
col
).
to
(
torch
.
long
)
col
=
torch
.
from_numpy
(
mat
.
col
).
to
(
torch
.
long
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
value
=
torch
.
from_numpy
(
mat
.
data
)
value
=
torch
.
from_numpy
(
mat
.
data
)
size
=
mat
.
shape
sparse_
size
=
mat
.
shape
[:
2
]
storage
=
SparseStorage
(
index
,
value
,
size
,
rowptr
=
rowptr
,
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
colptr
=
colptr
,
is_sorted
=
True
)
sparse_size
=
sparse_size
,
colptr
=
colptr
,
is_sorted
=
True
)
return
SparseTensor
.
from_storage
(
storage
)
return
SparseTensor
.
from_storage
(
storage
)
@
classmethod
@
classmethod
def
eye
(
self
,
M
,
N
=
None
,
device
=
None
,
dtype
=
None
,
no
_value
=
Fals
e
,
def
eye
(
self
,
M
,
N
=
None
,
device
=
None
,
dtype
=
None
,
has
_value
=
Tru
e
,
fill_cache
=
False
):
fill_cache
=
False
):
N
=
M
if
N
is
None
else
N
N
=
M
if
N
is
None
else
N
index
=
torch
.
empty
((
2
,
min
(
M
,
N
)),
dtype
=
torch
.
long
,
device
=
device
)
row
=
torch
.
arange
(
min
(
M
,
N
),
device
=
device
)
torch
.
arange
(
index
.
size
(
1
),
out
=
index
[
0
])
rowptr
=
torch
.
arange
(
M
+
1
,
device
=
device
)
torch
.
arange
(
index
.
size
(
1
),
out
=
index
[
1
])
if
M
>
N
:
rowptr
[
row
.
size
(
0
)
+
1
:]
=
row
.
size
(
0
)
col
=
row
value
=
None
value
=
None
if
not
no
_value
:
if
has
_value
:
value
=
torch
.
ones
(
index
.
size
(
1
),
dtype
=
dtype
,
device
=
device
)
value
=
torch
.
ones
(
row
.
size
(
0
),
dtype
=
dtype
,
device
=
device
)
rowcount
=
row
ptr
=
colcount
=
colptr
=
csr2csc
=
csc2csr
=
None
rowcount
=
col
ptr
=
colcount
=
csr2csc
=
csc2csr
=
None
if
fill_cache
:
if
fill_cache
:
rowcount
=
index
.
new_ones
(
M
)
rowcount
=
row
.
new_ones
(
M
)
rowptr
=
torch
.
arange
(
M
+
1
,
device
=
device
)
if
M
>
N
:
if
M
>
N
:
rowcount
[
index
.
size
(
1
):]
=
0
rowcount
[
row
.
size
(
0
):]
=
0
rowptr
[
index
.
size
(
1
)
+
1
:]
=
index
.
size
(
1
)
colcount
=
index
.
new_ones
(
N
)
colptr
=
torch
.
arange
(
N
+
1
,
device
=
device
)
colptr
=
torch
.
arange
(
N
+
1
,
device
=
device
)
colcount
=
col
.
new_ones
(
N
)
if
N
>
M
:
if
N
>
M
:
colcount
[
index
.
size
(
1
):]
=
0
colptr
[
col
.
size
(
0
)
+
1
:]
=
col
.
size
(
0
)
colptr
[
index
.
size
(
1
)
+
1
:]
=
index
.
size
(
1
)
colcount
[
col
.
size
(
0
):]
=
0
csr2csc
=
torch
.
arange
(
index
.
size
(
1
),
device
=
device
)
csr2csc
=
csc2csr
=
row
csc2csr
=
torch
.
arange
(
index
.
size
(
1
),
device
=
device
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
sparse_size
=
torch
.
Size
([
M
,
N
]),
index
,
rowcount
=
rowcount
,
colptr
=
colptr
,
value
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
torch
.
Size
([
M
,
N
]),
csc2csr
=
csc2csr
,
is_sorted
=
True
)
rowcount
=
rowcount
,
rowptr
=
rowptr
,
colcount
=
colcount
,
colptr
=
colptr
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
)
return
SparseTensor
.
from_storage
(
storage
)
return
SparseTensor
.
from_storage
(
storage
)
def
__copy__
(
self
):
def
__copy__
(
self
):
...
@@ -118,7 +114,7 @@ class SparseTensor(object):
...
@@ -118,7 +114,7 @@ class SparseTensor(object):
# Formats #################################################################
# Formats #################################################################
def
coo
(
self
):
def
coo
(
self
):
return
self
.
storage
.
index
,
self
.
storage
.
value
return
self
.
storage
.
row
,
self
.
storage
.
col
,
self
.
storage
.
value
def
csr
(
self
):
def
csr
(
self
):
return
self
.
storage
.
rowptr
,
self
.
storage
.
col
,
self
.
storage
.
value
return
self
.
storage
.
rowptr
,
self
.
storage
.
col
,
self
.
storage
.
value
...
@@ -133,15 +129,16 @@ class SparseTensor(object):
...
@@ -133,15 +129,16 @@ class SparseTensor(object):
def
has_value
(
self
):
def
has_value
(
self
):
return
self
.
storage
.
has_value
()
return
self
.
storage
.
has_value
()
def
set_value_
(
self
,
value
,
layout
=
None
):
def
set_value_
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
self
.
storage
.
set_value_
(
value
,
layout
)
self
.
storage
.
set_value_
(
value
,
layout
,
dtype
)
return
self
return
self
def
set_value
(
self
,
value
,
layout
=
None
):
def
set_value
(
self
,
value
,
layout
=
None
,
dtype
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
,
dtype
))
def
sparse_size
(
self
,
dim
=
None
):
def
sparse_size
(
self
,
dim
=
None
):
return
self
.
storage
.
sparse_size
(
dim
)
sparse_size
=
self
.
storage
.
sparse_size
return
sparse_size
if
dim
is
None
else
sparse_size
[
dim
]
def
sparse_resize
(
self
,
*
sizes
):
def
sparse_resize
(
self
,
*
sizes
):
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
*
sizes
))
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
*
sizes
))
...
@@ -165,20 +162,20 @@ class SparseTensor(object):
...
@@ -165,20 +162,20 @@ class SparseTensor(object):
# Utility functions #######################################################
# Utility functions #######################################################
def
dim
(
self
):
return
len
(
self
.
size
())
def
size
(
self
,
dim
=
None
):
def
size
(
self
,
dim
=
None
):
size
=
self
.
sparse_size
()
size
=
self
.
sparse_size
()
size
+=
self
.
storage
.
value
.
size
()[
1
:]
if
self
.
has_value
()
else
()
size
+=
self
.
storage
.
value
.
size
()[
1
:]
if
self
.
has_value
()
else
()
return
size
if
dim
is
None
else
size
[
dim
]
return
size
if
dim
is
None
else
size
[
dim
]
def
dim
(
self
):
return
len
(
self
.
size
())
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
return
self
.
size
()
return
self
.
size
()
def
nnz
(
self
):
def
nnz
(
self
):
return
self
.
storage
.
nnz
()
return
self
.
storage
.
col
.
numel
()
def
density
(
self
):
def
density
(
self
):
return
self
.
nnz
()
/
(
self
.
sparse_size
(
0
)
*
self
.
sparse_size
(
1
))
return
self
.
nnz
()
/
(
self
.
sparse_size
(
0
)
*
self
.
sparse_size
(
1
))
...
@@ -202,11 +199,16 @@ class SparseTensor(object):
...
@@ -202,11 +199,16 @@ class SparseTensor(object):
if
not
self
.
is_quadratic
:
if
not
self
.
is_quadratic
:
return
False
return
False
rowptr
,
col
,
val1
=
self
.
csr
()
rowptr
,
col
,
value1
=
self
.
csr
()
colptr
,
row
,
val2
=
self
.
csc
()
colptr
,
row
,
value2
=
self
.
csc
()
index_sym
=
(
rowptr
==
colptr
).
all
()
and
(
col
==
row
).
all
()
value_sym
=
(
val1
==
val2
).
all
().
item
()
if
self
.
has_value
()
else
True
if
(
rowptr
!=
colptr
).
any
()
or
(
col
!=
row
).
any
():
return
index_sym
.
item
()
and
value_sym
return
False
if
not
self
.
has_value
():
return
True
return
(
value1
==
value2
).
all
().
item
()
def
detach_
(
self
):
def
detach_
(
self
):
self
.
storage
.
apply_
(
lambda
x
:
x
.
detach_
())
self
.
storage
.
apply_
(
lambda
x
:
x
.
detach_
())
...
@@ -219,9 +221,13 @@ class SparseTensor(object):
...
@@ -219,9 +221,13 @@ class SparseTensor(object):
def
requires_grad
(
self
):
def
requires_grad
(
self
):
return
self
.
storage
.
value
.
requires_grad
if
self
.
has_value
()
else
False
return
self
.
storage
.
value
.
requires_grad
if
self
.
has_value
()
else
False
def
requires_grad_
(
self
,
requires_grad
=
True
):
def
requires_grad_
(
self
,
requires_grad
=
True
,
dtype
=
None
):
if
requires_grad
and
not
self
.
has_value
():
self
.
storage
.
set_value_
(
1
,
dtype
=
dtype
)
if
self
.
has_value
():
if
self
.
has_value
():
self
.
storage
.
value
.
requires_grad_
(
requires_grad
)
self
.
storage
.
value
.
requires_grad_
(
requires_grad
)
return
self
return
self
def
pin_memory
(
self
):
def
pin_memory
(
self
):
...
@@ -239,7 +245,7 @@ class SparseTensor(object):
...
@@ -239,7 +245,7 @@ class SparseTensor(object):
@
property
@
property
def
device
(
self
):
def
device
(
self
):
return
self
.
storage
.
index
.
device
return
self
.
storage
.
col
.
device
def
cpu
(
self
):
def
cpu
(
self
):
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
cpu
()))
return
self
.
from_storage
(
self
.
storage
.
apply
(
lambda
x
:
x
.
cpu
()))
...
@@ -251,7 +257,7 @@ class SparseTensor(object):
...
@@ -251,7 +257,7 @@ class SparseTensor(object):
@
property
@
property
def
is_cuda
(
self
):
def
is_cuda
(
self
):
return
self
.
storage
.
index
.
is_cuda
return
self
.
storage
.
col
.
is_cuda
@
property
@
property
def
dtype
(
self
):
def
dtype
(
self
):
...
@@ -296,7 +302,7 @@ class SparseTensor(object):
...
@@ -296,7 +302,7 @@ class SparseTensor(object):
if
len
(
args
)
>
0
or
len
(
kwargs
)
>
0
:
if
len
(
args
)
>
0
or
len
(
kwargs
)
>
0
:
storage
=
storage
.
apply_value
(
lambda
x
:
x
.
type
(
*
args
,
**
kwargs
))
storage
=
storage
.
apply_value
(
lambda
x
:
x
.
type
(
*
args
,
**
kwargs
))
if
storage
==
self
.
storage
:
# Nothing changed...
if
storage
==
self
.
storage
:
# Nothing
has been
changed...
return
self
return
self
else
:
else
:
return
self
.
from_storage
(
storage
)
return
self
.
from_storage
(
storage
)
...
@@ -335,19 +341,21 @@ class SparseTensor(object):
...
@@ -335,19 +341,21 @@ class SparseTensor(object):
def
to_dense
(
self
,
dtype
=
None
):
def
to_dense
(
self
,
dtype
=
None
):
dtype
=
dtype
or
self
.
dtype
dtype
=
dtype
or
self
.
dtype
(
row
,
col
)
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
mat
=
torch
.
zeros
(
self
.
size
(),
dtype
=
dtype
,
device
=
self
.
device
)
mat
=
torch
.
zeros
(
self
.
size
(),
dtype
=
dtype
,
device
=
self
.
device
)
mat
[
row
,
col
]
=
value
if
self
.
has_value
()
else
1
mat
[
row
,
col
]
=
value
if
self
.
has_value
()
else
1
return
mat
return
mat
def
to_torch_sparse_coo_tensor
(
self
,
dtype
=
None
,
requires_grad
=
False
):
def
to_torch_sparse_coo_tensor
(
self
,
dtype
=
None
,
requires_grad
=
False
):
index
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
return
torch
.
sparse_coo_tensor
(
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
index
,
value
if
self
.
has_value
()
else
torch
.
ones
(
if
value
is
None
:
self
.
nnz
(),
dtype
=
dtype
,
device
=
self
.
device
),
self
.
size
(),
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
,
device
=
self
.
device
)
device
=
self
.
device
,
requires_grad
=
requires_grad
)
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
size
(),
device
=
self
.
device
,
requires_grad
=
requires_grad
)
def
to_scipy
(
self
,
dtype
=
None
,
layout
=
None
):
def
to_scipy
(
self
,
dtype
=
None
,
layout
=
"csr"
):
assert
self
.
dim
()
==
2
assert
self
.
dim
()
==
2
layout
=
get_layout
(
layout
)
layout
=
get_layout
(
layout
)
...
@@ -355,7 +363,7 @@ class SparseTensor(object):
...
@@ -355,7 +363,7 @@ class SparseTensor(object):
ones
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
).
numpy
()
ones
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
).
numpy
()
if
layout
==
'coo'
:
if
layout
==
'coo'
:
(
row
,
col
)
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
row
=
row
.
detach
().
cpu
().
numpy
()
row
=
row
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
...
@@ -379,7 +387,7 @@ class SparseTensor(object):
...
@@ -379,7 +387,7 @@ class SparseTensor(object):
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
index
=
list
(
index
)
if
isinstance
(
index
,
tuple
)
else
[
index
]
# More than one `Ellipsis` is not allowed...
# More than one `Ellipsis` is not allowed...
if
len
([
i
for
i
in
index
if
not
torch
.
is_tensor
(
i
)
and
i
==
...])
>
1
:
if
len
([
i
for
i
in
index
if
not
torch
.
is_tensor
(
i
)
and
i
==
...])
>
1
:
raise
SyntaxError
()
raise
SyntaxError
dim
=
0
dim
=
0
out
=
self
out
=
self
...
@@ -416,6 +424,15 @@ class SparseTensor(object):
...
@@ -416,6 +424,15 @@ class SparseTensor(object):
return
out
return
out
def
__add__
(
self
,
other
):
return
self
.
add
(
other
)
def
__radd__
(
self
,
other
):
return
self
.
add
(
other
)
def
__iadd__
(
self
,
other
):
return
self
.
add_
(
other
)
def
__matmul__
(
a
,
b
):
def
__matmul__
(
a
,
b
):
return
matmul
(
a
,
b
,
reduce
=
'sum'
)
return
matmul
(
a
,
b
,
reduce
=
'sum'
)
...
@@ -423,8 +440,10 @@ class SparseTensor(object):
...
@@ -423,8 +440,10 @@ class SparseTensor(object):
def
__repr__
(
self
):
def
__repr__
(
self
):
i
=
' '
*
6
i
=
' '
*
6
index
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
infos
=
[
f
'index=
{
indent
(
index
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
=
[]
infos
+=
[
f
'row=
{
indent
(
row
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'col=
{
indent
(
col
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
if
self
.
has_value
():
if
self
.
has_value
():
infos
+=
[
f
'value=
{
indent
(
value
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
infos
+=
[
f
'value=
{
indent
(
value
.
__repr__
(),
i
)[
len
(
i
):]
}
'
]
...
@@ -456,8 +475,10 @@ SparseTensor.min = torch_sparse.reduce.min
...
@@ -456,8 +475,10 @@ SparseTensor.min = torch_sparse.reduce.min
SparseTensor
.
max
=
torch_sparse
.
reduce
.
max
SparseTensor
.
max
=
torch_sparse
.
reduce
.
max
SparseTensor
.
remove_diag
=
remove_diag
SparseTensor
.
remove_diag
=
remove_diag
SparseTensor
.
matmul
=
matmul
SparseTensor
.
matmul
=
matmul
# SparseTensor.add = add
SparseTensor
.
add
=
add
# SparseTensor.add_nnz = add_nnz
SparseTensor
.
add_
=
add_
SparseTensor
.
add_nnz
=
add_nnz
SparseTensor
.
add_nnz_
=
add_nnz_
# def __add__(self, other):
# def __add__(self, other):
# return self.add(other)
# return self.add(other)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment