Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
64b0ae30
Commit
64b0ae30
authored
Jan 27, 2020
by
rusty1s
Browse files
storage done
parent
f59fe649
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
86 additions
and
104 deletions
+86
-104
torch_sparse/storage.py
torch_sparse/storage.py
+86
-104
No files found.
torch_sparse/storage.py
View file @
64b0ae30
import
warnings
import
warnings
from
typing
import
Optional
,
List
,
Dict
,
Any
from
typing
import
Optional
,
List
,
Dict
,
Union
,
Any
import
torch
import
torch
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_scatter
import
segment_csr
,
scatter_add
from
torch_sparse.utils
import
Final
from
torch_sparse.utils
import
Final
,
is_scalar
# Module-level switch controlling whether derived quantities may be cached.
__cache__ = {'enabled': True}


def is_cache_enabled():
    """Return ``True`` when global caching is currently turned on."""
    return __cache__['enabled']


def set_cache_enabled(mode):
    """Turn global caching on (truthy ``mode``) or off (falsy ``mode``)."""
    __cache__['enabled'] = mode
class no_cache(object):
    """Context manager / decorator that disables global caching in its scope.

    On entry the previous cache state is remembered and caching is switched
    off; on exit the previous state is restored.  ``__exit__`` returns
    ``False`` so any exception raised inside the block propagates.
    """

    def __enter__(self):
        self.prev = is_cache_enabled()
        set_cache_enabled(False)

    def __exit__(self, *args):
        set_cache_enabled(self.prev)
        return False

    def __call__(self, func):
        # Local import keeps the module's public import surface unchanged.
        from functools import wraps

        # FIX: preserve the wrapped function's metadata (__name__, __doc__,
        # ...) — the previous version returned an anonymous-looking wrapper.
        @wraps(func)
        def decorate_no_cache(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return decorate_no_cache
# class cached_property(object):
# def __init__(self, func):
# self.func = func
# def __get__(self, obj, cls):
# value = getattr(obj, f'_{self.func.__name__}', None)
# if value is None:
# value = self.func(obj)
# if is_cache_enabled():
# setattr(obj, f'_{self.func.__name__}', value)
# return value
def
optional
(
func
,
src
):
def
optional
(
func
,
src
):
...
@@ -53,12 +37,12 @@ def optional(func, src):
...
@@ -53,12 +37,12 @@ def optional(func, src):
# Recognized sparse layout identifiers.  Annotated with `Final` (from
# torch_sparse.utils) — NOTE(review): presumably so TorchScript treats the
# list as a constant; confirm against torch_sparse.utils.
layouts: Final[List[str]] = ['coo', 'csr', 'csc']
def get_layout(layout: Optional[str] = None) -> str:
    """Return a validated layout string, defaulting to ``'coo'`` when unset.

    A warning is emitted when the caller relied on the default.
    """
    if layout is None:
        warnings.warn('`layout` argument unset, using default layout '
                      '"coo". This may lead to unexpected behaviour.')
        layout = 'coo'
    # Spelled out as explicit comparisons (rather than `layout in layouts`),
    # matching the file's TorchScript-friendly style.
    assert layout == 'coo' or layout == 'csr' or layout == 'csc'
    return layout
...
@@ -237,78 +221,79 @@ class SparseStorage(object):
...
@@ -237,78 +221,79 @@ class SparseStorage(object):
def value(self) -> Optional[torch.Tensor]:
    """Return the stored value tensor, or ``None`` when no values are set."""
    return self._value
def set_value_(self, value: Optional[torch.Tensor],
               layout: Optional[str] = None):
    """Replace the value tensor in-place and return ``self``.

    When ``layout='csc'`` the incoming values are assumed to be in
    column-major order and are permuted into the internal row-major order
    via ``self.csc2csr()``.  Passing ``value=None`` clears the values.
    """
    if value is not None:
        # BUG FIX: this previously compared against 'csc2csr', which
        # `get_layout` can never return (it only yields 'coo'/'csr'/'csc'),
        # so the permutation branch was dead code.
        if get_layout(layout) == 'csc':
            value = value[self.csc2csr()]
        value = value.contiguous()
        # Values must live with the indices and have one entry per nonzero.
        assert value.device == self._col.device
        assert value.size(0) == self._col.numel()
    self._value = value
    return self
def set_value(self, value: Optional[torch.Tensor],
              layout: Optional[str] = None):
    """Out-of-place variant of :meth:`set_value_`.

    Returns a new ``SparseStorage`` that shares every cached field of this
    one, with only the value tensor replaced.  When ``layout='csc'`` the
    incoming values are permuted into row-major order first.
    """
    if value is not None:
        # BUG FIX: previously compared against 'csc2csr', which `get_layout`
        # can never return — the permutation branch was unreachable.
        if get_layout(layout) == 'csc':
            value = value[self.csc2csr()]
        value = value.contiguous()
        assert value.device == self._col.device
        assert value.size(0) == self._col.numel()
    return SparseStorage(row=self._row, rowptr=self._rowptr, col=self._col,
                         value=value, sparse_size=self._sparse_size,
                         rowcount=self._rowcount, colptr=self._colptr,
                         colcount=self._colcount, csr2csc=self._csr2csc,
                         csc2csr=self._csc2csr, is_sorted=True)
def fill_value_(self, fill_value: float,
                dtype: Optional[torch.dtype] = None):
    """In-place: set every nonzero's value to ``fill_value``; return ``self``.

    ``dtype`` selects the dtype of the new value tensor (``None`` uses the
    default dtype).
    """
    # BUG FIX: the signature previously read `dtype=Optional[torch.dtype]`,
    # making the typing object itself the *default value*, and `dtype` was
    # passed positionally to `torch.empty`, where it would be interpreted as
    # an extra size dimension and raise.  Annotate properly and pass by
    # keyword.
    value = torch.empty(self._col.numel(), dtype=dtype,
                        device=self._col.device)
    return self.set_value_(value.fill_(fill_value), layout='csr')
def fill_value(self, fill_value: float,
               dtype: Optional[torch.dtype] = None):
    """Out-of-place: return a new storage whose values all equal ``fill_value``.

    ``dtype`` selects the dtype of the new value tensor (``None`` uses the
    default dtype).
    """
    # BUG FIX: as in `fill_value_`, `dtype=Optional[torch.dtype]` made the
    # typing object the default value, and passing it positionally to
    # `torch.empty` would be read as a size dimension and raise.
    value = torch.empty(self._col.numel(), dtype=dtype,
                        device=self._col.device)
    return self.set_value(value.fill_(fill_value), layout='csr')
def sparse_size(self) -> List[int]:
    """Return the sparse matrix dimensions as ``[num_rows, num_cols]``."""
    return self._sparse_size
def sparse_resize(self, sparse_size: List[int]):
    """Return a new storage resized to ``sparse_size`` (``[rows, cols]``).

    Growing a dimension pads the cached ``rowptr``/``colptr`` with ``nnz``
    and the cached ``rowcount``/``colcount`` with zeros; shrinking truncates
    them.  Indices and values are carried over unchanged, so shrinking
    assumes the removed rows/columns hold no nonzeros — TODO confirm callers
    guarantee this.
    """
    assert len(sparse_size) == 2
    old_sparse_size, nnz = self._sparse_size, self._col.numel()

    diff_0 = sparse_size[0] - old_sparse_size[0]
    rowcount, rowptr = self._rowcount, self._rowptr
    if diff_0 > 0:
        if rowptr is not None:
            rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
        if rowcount is not None:
            rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
    elif diff_0 < 0:
        # BUG FIX: this used `rowptr[:-diff_0]`, which for negative `diff_0`
        # keeps only the *first* `-diff_0` entries, and for `diff_0 == 0`
        # evaluates to `[:0]`, emptying the tensor.  Dropping the trailing
        # `-diff_0` entries is `[:diff_0]`, and `diff_0 == 0` is a no-op.
        if rowptr is not None:
            rowptr = rowptr[:diff_0]
        if rowcount is not None:
            rowcount = rowcount[:diff_0]

    diff_1 = sparse_size[1] - old_sparse_size[1]
    colcount, colptr = self._colcount, self._colptr
    if diff_1 > 0:
        if colptr is not None:
            colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
        if colcount is not None:
            colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
    elif diff_1 < 0:
        # Same fix as the row dimension above.
        if colptr is not None:
            colptr = colptr[:diff_1]
        if colcount is not None:
            colcount = colcount[:diff_1]

    return SparseStorage(row=self._row, rowptr=rowptr, col=self._col,
                         value=self._value, sparse_size=sparse_size,
                         rowcount=rowcount, colptr=colptr,
                         colcount=colcount, csr2csc=self._csr2csc,
                         csc2csr=self._csc2csr, is_sorted=True)
def has_rowcount(self) -> bool:
    """Return whether the per-row nonzero counts are cached."""
    return self._rowcount is not None
...
@@ -431,7 +416,7 @@ class SparseStorage(object):
...
@@ -431,7 +416,7 @@ class SparseStorage(object):
self
.
_csc2csr
=
None
self
.
_csc2csr
=
None
return
self
return
self
def
__
copy
__
(
self
):
def
copy
(
self
):
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
return
SparseStorage
(
row
=
self
.
_row
,
rowptr
=
self
.
_rowptr
,
col
=
self
.
_col
,
value
=
self
.
_value
,
sparse_size
=
self
.
_sparse_size
,
value
=
self
.
_value
,
sparse_size
=
self
.
_sparse_size
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
rowcount
=
self
.
_rowcount
,
colptr
=
self
.
_colptr
,
...
@@ -468,6 +453,3 @@ class SparseStorage(object):
...
@@ -468,6 +453,3 @@ class SparseStorage(object):
rowcount
=
rowcount
,
colptr
=
colptr
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
csc2csr
=
csc2csr
,
is_sorted
=
True
)
def __deepcopy__(self, memo: Dict[str, Any]):
    """Hook for ``copy.deepcopy``: delegate to ``clone()`` (``memo`` unused)."""
    return self.clone()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment