Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
bfb571cb
Unverified
Commit
bfb571cb
authored
Feb 23, 2020
by
Matthias Fey
Committed by
GitHub
Feb 23, 2020
Browse files
Merge pull request #40 from rusty1s/metis
[WIP] Partition
parents
e78637ea
eee47eee
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
91 additions
and
144 deletions
+91
-144
torch_sparse/mul.py
torch_sparse/mul.py
+0
-4
torch_sparse/narrow.py
torch_sparse/narrow.py
+2
-5
torch_sparse/permute.py
torch_sparse/permute.py
+10
-0
torch_sparse/reduce.py
torch_sparse/reduce.py
+0
-5
torch_sparse/select.py
torch_sparse/select.py
+0
-2
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+2
-2
torch_sparse/storage.py
torch_sparse/storage.py
+59
-110
torch_sparse/tensor.py
torch_sparse/tensor.py
+16
-13
torch_sparse/transpose.py
torch_sparse/transpose.py
+2
-3
No files found.
torch_sparse/mul.py
View file @
bfb571cb
...
...
@@ -5,7 +5,6 @@ from torch_scatter import gather_csr
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
mul
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
...
@@ -25,7 +24,6 @@ def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
mul_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
)
->
SparseTensor
:
rowptr
,
col
,
value
=
src
.
csr
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
...
...
@@ -45,7 +43,6 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
@
torch
.
jit
.
script
def
mul_nnz
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
...
...
@@ -56,7 +53,6 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
return
src
.
set_value
(
value
,
layout
=
layout
)
@
torch
.
jit
.
script
def
mul_nnz_
(
src
:
SparseTensor
,
other
:
torch
.
Tensor
,
layout
:
Optional
[
str
]
=
None
)
->
SparseTensor
:
value
=
src
.
storage
.
value
()
...
...
torch_sparse/narrow.py
View file @
bfb571cb
from
typing
import
Tuple
import
torch
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
narrow
(
src
:
SparseTensor
,
dim
:
int
,
start
:
int
,
length
:
int
)
->
SparseTensor
:
if
dim
<
0
:
...
...
@@ -31,7 +29,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if
value
is
not
None
:
value
=
value
.
narrow
(
0
,
row_start
,
row_length
)
sparse_sizes
=
torch
.
Size
([
length
,
src
.
sparse_size
(
1
)
]
)
sparse_sizes
=
(
length
,
src
.
sparse_size
(
1
))
rowcount
=
src
.
storage
.
_rowcount
if
rowcount
is
not
None
:
...
...
@@ -54,7 +52,7 @@ def narrow(src: SparseTensor, dim: int, start: int,
if
value
is
not
None
:
value
=
value
[
mask
]
sparse_sizes
=
torch
.
Size
([
src
.
sparse_size
(
0
),
length
]
)
sparse_sizes
=
(
src
.
sparse_size
(
0
),
length
)
colptr
=
src
.
storage
.
_colptr
if
colptr
is
not
None
:
...
...
@@ -80,7 +78,6 @@ def narrow(src: SparseTensor, dim: int, start: int,
raise
ValueError
@
torch
.
jit
.
script
def
__narrow_diag__
(
src
:
SparseTensor
,
start
:
Tuple
[
int
,
int
],
length
:
Tuple
[
int
,
int
])
->
SparseTensor
:
# This function builds the inverse operation of `cat_diag` and should hence
...
...
torch_sparse/permute.py
0 → 100644
View file @
bfb571cb
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
permute
(
src
:
SparseTensor
,
perm
:
torch
.
Tensor
)
->
SparseTensor
:
assert
src
.
is_quadratic
()
return
src
.
index_select
(
0
,
perm
).
index_select
(
1
,
perm
)
SparseTensor
.
permute
=
lambda
self
,
perm
:
permute
(
self
,
perm
)
torch_sparse/reduce.py
View file @
bfb571cb
...
...
@@ -5,7 +5,6 @@ from torch_scatter import scatter, segment_csr
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
reduction
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
,
reduce
:
str
=
'sum'
)
->
torch
.
Tensor
:
value
=
src
.
storage
.
value
()
...
...
@@ -68,22 +67,18 @@ def reduction(src: SparseTensor, dim: Optional[int] = None,
raise
ValueError
@
torch
.
jit
.
script
def
sum
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'sum'
)
@
torch
.
jit
.
script
def
mean
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'mean'
)
@
torch
.
jit
.
script
def
min
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'min'
)
@
torch
.
jit
.
script
def
max
(
src
:
SparseTensor
,
dim
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
reduction
(
src
,
dim
,
reduce
=
'max'
)
...
...
torch_sparse/select.py
View file @
bfb571cb
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.narrow
import
narrow
@
torch
.
jit
.
script
def
select
(
src
:
SparseTensor
,
dim
:
int
,
idx
:
int
)
->
SparseTensor
:
return
narrow
(
src
,
dim
,
start
=
idx
,
length
=
1
)
...
...
torch_sparse/spspmm.py
View file @
bfb571cb
...
...
@@ -23,9 +23,9 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
"""
A
=
SparseTensor
(
row
=
indexA
[
0
],
col
=
indexA
[
1
],
value
=
valueA
,
sparse_sizes
=
torch
.
Size
([
m
,
k
]
),
is_sorted
=
not
coalesced
)
sparse_sizes
=
(
m
,
k
),
is_sorted
=
not
coalesced
)
B
=
SparseTensor
(
row
=
indexB
[
0
],
col
=
indexB
[
1
],
value
=
valueB
,
sparse_sizes
=
torch
.
Size
([
k
,
n
]
),
is_sorted
=
not
coalesced
)
sparse_sizes
=
(
k
,
n
),
is_sorted
=
not
coalesced
)
C
=
matmul
(
A
,
B
)
row
,
col
,
value
=
C
.
coo
()
...
...
torch_sparse/storage.py
View file @
bfb571cb
import
warnings
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Tuple
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
...
...
@@ -23,19 +23,18 @@ class SparseStorage(object):
_rowptr
:
Optional
[
torch
.
Tensor
]
_col
:
torch
.
Tensor
_value
:
Optional
[
torch
.
Tensor
]
_sparse_sizes
:
List
[
int
]
_sparse_sizes
:
Tuple
[
int
,
int
]
_rowcount
:
Optional
[
torch
.
Tensor
]
_colptr
:
Optional
[
torch
.
Tensor
]
_colcount
:
Optional
[
torch
.
Tensor
]
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
List
[
int
]]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -57,7 +56,7 @@ class SparseStorage(object):
else
:
raise
ValueError
N
=
col
.
max
().
item
()
+
1
sparse_sizes
=
torch
.
Size
([
int
(
M
),
int
(
N
)
]
)
sparse_sizes
=
(
int
(
M
),
int
(
N
))
else
:
assert
len
(
sparse_sizes
)
==
2
...
...
@@ -119,7 +118,7 @@ class SparseStorage(object):
self
.
_rowptr
=
rowptr
self
.
_col
=
col
self
.
_value
=
value
self
.
_sparse_sizes
=
sparse_sizes
self
.
_sparse_sizes
=
tuple
(
sparse_sizes
)
self
.
_rowcount
=
rowcount
self
.
_colptr
=
colptr
self
.
_colcount
=
colcount
...
...
@@ -192,8 +191,7 @@ class SparseStorage(object):
def
value
(
self
)
->
Optional
[
torch
.
Tensor
]:
return
self
.
_value
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
...
...
@@ -205,8 +203,7 @@ class SparseStorage(object):
self
.
_value
=
value
return
self
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
if
value
is
not
None
:
if
get_layout
(
layout
)
==
'csc'
:
...
...
@@ -215,26 +212,19 @@ class SparseStorage(object):
assert
value
.
device
==
self
.
_col
.
device
assert
value
.
size
(
0
)
==
self
.
_col
.
numel
()
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_sizes
(
self
)
->
List
[
int
]:
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
_sparse_sizes
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
return
self
.
_sparse_sizes
[
dim
]
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]):
def
sparse_resize
(
self
,
sparse_sizes
:
Tuple
[
int
,
int
]):
assert
len
(
sparse_sizes
)
==
2
old_sparse_sizes
,
nnz
=
self
.
_sparse_sizes
,
self
.
_col
.
numel
()
...
...
@@ -264,18 +254,11 @@ class SparseStorage(object):
if
colcount
is
not
None
:
colcount
=
colcount
[:
-
diff_1
]
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
...
...
@@ -320,10 +303,8 @@ class SparseStorage(object):
if
colptr
is
not
None
:
colcount
=
colptr
[
1
:]
-
colptr
[:
-
1
]
else
:
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
colcount
=
scatter_add
(
torch
.
ones_like
(
self
.
_col
),
self
.
_col
,
dim_size
=
self
.
_sparse_sizes
[
1
])
self
.
_colcount
=
colcount
return
colcount
...
...
@@ -375,18 +356,10 @@ class SparseStorage(object):
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
value
[
0
]
if
isinstance
(
value
,
tuple
)
else
value
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
def
fill_cache_
(
self
):
self
.
row
()
...
...
@@ -406,33 +379,30 @@ class SparseStorage(object):
self
.
_csc2csr
=
None
return
self
def
num_
cached_keys
(
self
)
->
int
:
count
=
0
def
cached_keys
(
self
)
->
List
[
str
]
:
keys
:
List
[
str
]
=
[]
if
self
.
has_rowcount
():
count
+=
1
keys
.
append
(
'rowcount'
)
if
self
.
has_colptr
():
count
+=
1
keys
.
append
(
'colptr'
)
if
self
.
has_colcount
():
count
+=
1
keys
.
append
(
'colcount'
)
if
self
.
has_csr2csc
():
count
+=
1
keys
.
append
(
'csr2csc'
)
if
self
.
has_csc2csr
():
count
+=
1
return
count
keys
.
append
(
'csc2csr'
)
return
keys
def
num_cached_keys
(
self
)
->
int
:
return
len
(
self
.
cached_keys
())
def
copy
(
self
):
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
colcount
=
self
.
_colcount
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
)
def
clone
(
self
):
row
=
self
.
_row
...
...
@@ -460,18 +430,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
type_as
(
self
,
tensor
=
torch
.
Tensor
):
value
=
self
.
_value
...
...
@@ -512,18 +475,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
tensor
.
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
pin_memory
(
self
):
row
=
self
.
_row
...
...
@@ -551,18 +507,11 @@ class SparseStorage(object):
csc2csr
=
self
.
_csc2csr
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
...
...
torch_sparse/tensor.py
View file @
bfb571cb
...
...
@@ -16,7 +16,8 @@ class SparseTensor(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
List
[
int
]
=
None
,
is_sorted
:
bool
=
False
):
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
is_sorted
:
bool
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
...
...
@@ -45,7 +46,8 @@ class SparseTensor(object):
value
=
mat
[
row
,
col
]
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
,
...
...
@@ -59,7 +61,8 @@ class SparseTensor(object):
value
=
mat
.
_values
()
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
mat
.
size
()[:
2
],
is_sorted
=
True
)
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
...
...
@@ -105,10 +108,9 @@ class SparseTensor(object):
csr2csc
=
csc2csr
=
row
storage
:
SparseStorage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
torch
.
Size
([
M
,
N
]),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
.
storage
=
storage
...
...
@@ -160,13 +162,13 @@ class SparseTensor(object):
layout
:
Optional
[
str
]
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
def
sparse_sizes
(
self
)
->
List
[
int
]:
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
storage
.
sparse_sizes
()
def
sparse_size
(
self
,
dim
:
int
)
->
int
:
return
self
.
storage
.
sparse_sizes
()[
dim
]
def
sparse_resize
(
self
,
sparse_sizes
:
List
[
int
]):
def
sparse_resize
(
self
,
sparse_sizes
:
Tuple
[
int
,
int
]):
return
self
.
from_storage
(
self
.
storage
.
sparse_resize
(
sparse_sizes
))
def
is_coalesced
(
self
)
->
bool
:
...
...
@@ -206,11 +208,12 @@ class SparseTensor(object):
return
self
.
set_value
(
value
,
layout
=
'coo'
)
def
sizes
(
self
)
->
List
[
int
]:
sizes
=
self
.
sparse_sizes
()
sparse_
sizes
=
self
.
sparse_sizes
()
value
=
self
.
storage
.
value
()
if
value
is
not
None
:
sizes
=
list
(
sizes
)
+
list
(
value
.
size
())[
1
:]
return
sizes
return
list
(
sparse_sizes
)
+
list
(
value
.
size
())[
1
:]
else
:
return
list
(
sparse_sizes
)
def
size
(
self
,
dim
:
int
)
->
int
:
return
self
.
sizes
()[
dim
]
...
...
@@ -268,7 +271,7 @@ class SparseTensor(object):
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
torch
.
Size
([
N
,
N
]
),
is_sorted
=
False
)
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
out
=
out
.
coalesce
(
reduce
)
return
out
...
...
torch_sparse/transpose.py
View file @
bfb571cb
...
...
@@ -4,7 +4,6 @@ from torch_sparse.storage import SparseStorage
from
torch_sparse.tensor
import
SparseTensor
@
torch
.
jit
.
script
def
t
(
src
:
SparseTensor
)
->
SparseTensor
:
csr2csc
=
src
.
storage
.
csr2csc
()
...
...
@@ -20,7 +19,7 @@ def t(src: SparseTensor) -> SparseTensor:
rowptr
=
src
.
storage
.
_colptr
,
col
=
row
[
csr2csc
],
value
=
value
,
sparse_sizes
=
torch
.
Size
([
sparse_sizes
[
1
],
sparse_sizes
[
0
]
]
),
sparse_sizes
=
(
sparse_sizes
[
1
],
sparse_sizes
[
0
]),
rowcount
=
src
.
storage
.
_colcount
,
colptr
=
src
.
storage
.
_rowptr
,
colcount
=
src
.
storage
.
_rowcount
,
...
...
@@ -54,7 +53,7 @@ def transpose(index, value, m, n, coalesced=True):
row
,
col
=
col
,
row
if
coalesced
:
sparse_sizes
=
torch
.
Size
([
n
,
m
]
)
sparse_sizes
=
(
n
,
m
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
is_sorted
=
False
)
storage
=
storage
.
coalesce
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment