Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
b0ff709e
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fe5f035f797a5fa663a98030c9d0ec2f982cd09d"
Commit
b0ff709e
authored
Mar 11, 2022
by
rusty1s
Browse files
torch_csr_tensor
parent
fcf15650
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
27 deletions
+58
-27
test/test_matmul.py
test/test_matmul.py
+2
-3
torch_sparse/matmul.py
torch_sparse/matmul.py
+1
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+55
-24
No files found.
test/test_matmul.py
View file @
b0ff709e
...
@@ -2,12 +2,11 @@ from itertools import product
...
@@ -2,12 +2,11 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
import
torch_scatter
from
torch_sparse.matmul
import
matmul
from
torch_sparse.matmul
import
matmul
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
import
torch_scatter
from
.utils
import
reductions
,
devices
,
grad_dtypes
from
.utils
import
devices
,
grad_dtypes
,
reductions
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
@
pytest
.
mark
.
parametrize
(
'dtype,device,reduce'
,
...
...
torch_sparse/matmul.py
View file @
b0ff709e
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
...
...
torch_sparse/tensor.py
View file @
b0ff709e
from
textwrap
import
indent
from
textwrap
import
indent
from
typing
import
Optional
,
List
,
Tuple
,
Dict
,
Union
,
Any
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
numpy
as
np
import
numpy
as
np
import
scipy.sparse
import
scipy.sparse
import
torch
from
torch_scatter
import
segment_csr
from
torch_scatter
import
segment_csr
from
torch_sparse.storage
import
SparseStorage
,
get_layout
from
torch_sparse.storage
import
SparseStorage
,
get_layout
...
@@ -13,14 +13,16 @@ from torch_sparse.storage import SparseStorage, get_layout
...
@@ -13,14 +13,16 @@ from torch_sparse.storage import SparseStorage, get_layout
class
SparseTensor
(
object
):
class
SparseTensor
(
object
):
storage
:
SparseStorage
storage
:
SparseStorage
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
self
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
col
:
Optional
[
torch
.
Tensor
]
=
None
,
Optional
[
int
]]]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
trust_data
:
bool
=
False
):
is_sorted
:
bool
=
False
,
trust_data
:
bool
=
False
,
):
self
.
storage
=
SparseStorage
(
self
.
storage
=
SparseStorage
(
row
=
row
,
row
=
row
,
rowptr
=
rowptr
,
rowptr
=
rowptr
,
...
@@ -33,7 +35,8 @@ class SparseTensor(object):
...
@@ -33,7 +35,8 @@ class SparseTensor(object):
csr2csc
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
csc2csr
=
None
,
is_sorted
=
is_sorted
,
is_sorted
=
is_sorted
,
trust_data
=
trust_data
)
trust_data
=
trust_data
,
)
@
classmethod
@
classmethod
def
from_storage
(
self
,
storage
:
SparseStorage
):
def
from_storage
(
self
,
storage
:
SparseStorage
):
...
@@ -44,7 +47,8 @@ class SparseTensor(object):
...
@@ -44,7 +47,8 @@ class SparseTensor(object):
value
=
storage
.
_value
,
value
=
storage
.
_value
,
sparse_sizes
=
storage
.
_sparse_sizes
,
sparse_sizes
=
storage
.
_sparse_sizes
,
is_sorted
=
True
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
out
.
storage
.
_rowcount
=
storage
.
_rowcount
out
.
storage
.
_rowcount
=
storage
.
_rowcount
out
.
storage
.
_colptr
=
storage
.
_colptr
out
.
storage
.
_colptr
=
storage
.
_colptr
out
.
storage
.
_colcount
=
storage
.
_colcount
out
.
storage
.
_colcount
=
storage
.
_colcount
...
@@ -53,12 +57,14 @@ class SparseTensor(object):
...
@@ -53,12 +57,14 @@ class SparseTensor(object):
return
out
return
out
@
classmethod
@
classmethod
def
from_edge_index
(
self
,
edge_index
:
torch
.
Tensor
,
def
from_edge_index
(
edge_attr
:
Optional
[
torch
.
Tensor
]
=
None
,
self
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
edge_index
:
torch
.
Tensor
,
Optional
[
int
]]]
=
None
,
edge_attr
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
trust_data
:
bool
=
False
):
is_sorted
:
bool
=
False
,
trust_data
:
bool
=
False
,
):
return
SparseTensor
(
row
=
edge_index
[
0
],
rowptr
=
None
,
col
=
edge_index
[
1
],
return
SparseTensor
(
row
=
edge_index
[
0
],
rowptr
=
None
,
col
=
edge_index
[
1
],
value
=
edge_attr
,
sparse_sizes
=
sparse_sizes
,
value
=
edge_attr
,
sparse_sizes
=
sparse_sizes
,
is_sorted
=
is_sorted
,
trust_data
=
trust_data
)
is_sorted
=
is_sorted
,
trust_data
=
trust_data
)
...
@@ -97,6 +103,20 @@ class SparseTensor(object):
...
@@ -97,6 +103,20 @@ class SparseTensor(object):
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
,
trust_data
=
True
)
is_sorted
=
True
,
trust_data
=
True
)
@
classmethod
def
from_torch_sparse_csr_tensor
(
self
,
mat
:
torch
.
Tensor
,
has_value
:
bool
=
True
):
rowptr
=
mat
.
crow_indices
()
col
=
mat
.
col_indices
()
value
:
Optional
[
torch
.
Tensor
]
=
None
if
has_value
:
value
=
mat
.
values
()
return
SparseTensor
(
row
=
None
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
,
trust_data
=
True
)
@
classmethod
@
classmethod
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
has_value
:
bool
=
True
,
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
has_value
:
bool
=
True
,
dtype
:
Optional
[
int
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
int
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
...
@@ -140,7 +160,8 @@ class SparseTensor(object):
...
@@ -140,7 +160,8 @@ class SparseTensor(object):
value
=
value
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
sparse_sizes
=
(
M
,
N
),
is_sorted
=
True
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
out
.
storage
.
_rowcount
=
rowcount
out
.
storage
.
_rowcount
=
rowcount
out
.
storage
.
_colptr
=
colptr
out
.
storage
.
_colptr
=
colptr
out
.
storage
.
_colcount
=
colcount
out
.
storage
.
_colcount
=
colcount
...
@@ -158,8 +179,8 @@ class SparseTensor(object):
...
@@ -158,8 +179,8 @@ class SparseTensor(object):
value
=
self
.
storage
.
value
()
value
=
self
.
storage
.
value
()
if
value
is
None
or
dtype
==
value
.
dtype
:
if
value
is
None
or
dtype
==
value
.
dtype
:
return
self
return
self
return
self
.
from_storage
(
self
.
storage
.
type
(
return
self
.
from_storage
(
dtype
=
dtype
,
non_blocking
=
non_blocking
))
self
.
storage
.
type
(
dtype
=
dtype
,
non_blocking
=
non_blocking
))
def
type_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
def
type_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
return
self
.
type
(
dtype
=
tensor
.
dtype
,
non_blocking
=
non_blocking
)
return
self
.
type
(
dtype
=
tensor
.
dtype
,
non_blocking
=
non_blocking
)
...
@@ -167,8 +188,8 @@ class SparseTensor(object):
...
@@ -167,8 +188,8 @@ class SparseTensor(object):
def
to_device
(
self
,
device
:
torch
.
device
,
non_blocking
:
bool
=
False
):
def
to_device
(
self
,
device
:
torch
.
device
,
non_blocking
:
bool
=
False
):
if
device
==
self
.
device
():
if
device
==
self
.
device
():
return
self
return
self
return
self
.
from_storage
(
self
.
storage
.
to_device
(
return
self
.
from_storage
(
device
=
device
,
non_blocking
=
non_blocking
))
self
.
storage
.
to_device
(
device
=
device
,
non_blocking
=
non_blocking
))
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
return
self
.
to_device
(
device
=
tensor
.
device
,
non_blocking
=
non_blocking
)
return
self
.
to_device
(
device
=
tensor
.
device
,
non_blocking
=
non_blocking
)
...
@@ -362,7 +383,8 @@ class SparseTensor(object):
...
@@ -362,7 +383,8 @@ class SparseTensor(object):
value
=
value
,
value
=
value
,
sparse_sizes
=
(
N
,
N
),
sparse_sizes
=
(
N
,
N
),
is_sorted
=
True
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
return
out
return
out
def
detach_
(
self
):
def
detach_
(
self
):
...
@@ -479,6 +501,15 @@ class SparseTensor(object):
...
@@ -479,6 +501,15 @@ class SparseTensor(object):
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
sizes
())
return
torch
.
sparse_coo_tensor
(
index
,
value
,
self
.
sizes
())
def
to_torch_sparse_csr_tensor
(
self
,
dtype
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
rowptr
,
col
,
value
=
self
.
csr
()
if
value
is
None
:
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
,
device
=
self
.
device
())
return
torch
.
sparse_csr_tensor
(
rowptr
,
col
,
value
,
self
.
sizes
())
# Python Bindings #############################################################
# Python Bindings #############################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment