Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
335dfed0
Commit
335dfed0
authored
Jan 23, 2020
by
rusty1s
Browse files
bugfixes
parent
6a7f10e5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
50 deletions
+21
-50
torch_sparse/storage.py
torch_sparse/storage.py
+3
-2
torch_sparse/tensor.py
torch_sparse/tensor.py
+18
-48
No files found.
torch_sparse/storage.py
View file @
335dfed0
...
@@ -175,6 +175,7 @@ class SparseStorage(object):
...
@@ -175,6 +175,7 @@ class SparseStorage(object):
value
=
torch
.
full
((
self
.
nnz
(),
),
device
=
self
.
index
.
device
)
value
=
torch
.
full
((
self
.
nnz
(),
),
device
=
self
.
index
.
device
)
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
elif
torch
.
is_tensor
(
value
)
and
get_layout
(
layout
)
==
'csc'
:
value
=
value
[
self
.
csc2csr
]
value
=
value
[
self
.
csc2csr
]
if
torch
.
is_tensor
(
value
):
assert
value
.
device
==
self
.
_index
.
device
assert
value
.
device
==
self
.
_index
.
device
assert
value
.
size
(
0
)
==
self
.
_index
.
size
(
1
)
assert
value
.
size
(
0
)
==
self
.
_index
.
size
(
1
)
return
self
.
__class__
(
return
self
.
__class__
(
...
...
torch_sparse/tensor.py
View file @
335dfed0
...
@@ -274,27 +274,32 @@ class SparseTensor(object):
...
@@ -274,27 +274,32 @@ class SparseTensor(object):
return
self
.
from_storage
(
storage
)
return
self
.
from_storage
(
storage
)
def to(self, *args, **kwargs):
    """Perform dtype and/or device conversion, mirroring ``torch.Tensor.to``.

    A target device — given either as the ``device`` keyword, or positionally
    as a ``str`` or ``torch.device`` — is applied to *all* storage tensors via
    ``self.storage.apply``. Any remaining ``args``/``kwargs`` are treated as a
    dtype conversion and forwarded to ``torch.Tensor.type`` on the value
    tensor only. Returns ``self`` unchanged when nothing needs converting,
    otherwise a new tensor built via ``self.from_storage``.
    """
    args = list(args)
    # BUGFIX: `kwargs` is a plain dict, so the original
    # `getattr(kwargs, 'non_blocking', False)` always missed and silently
    # defaulted to False. Use .pop() so the flag is both honored and removed
    # from `kwargs` (it must not leak into the Tensor.type() call below).
    non_blocking = kwargs.pop('non_blocking', False)

    storage = None
    if 'device' in kwargs:
        device = kwargs['device']
        del kwargs['device']
        storage = self.storage.apply(
            lambda x: x.to(device, non_blocking=non_blocking))
    else:
        # The device may also be given positionally; pull it out of `args`
        # (iterate over a copy since we mutate the list).
        for arg in args[:]:
            if isinstance(arg, str) or isinstance(arg, torch.device):
                storage = self.storage.apply(
                    lambda x: x.to(arg, non_blocking=non_blocking))
                args.remove(arg)

    storage = self.storage if storage is None else storage

    if len(args) > 0 or len(kwargs) > 0:
        # Anything left over is a dtype conversion on the value tensor.
        storage = storage.apply_value(lambda x: x.type(*args, **kwargs))

    if storage == self.storage:  # Nothing changed...
        return self
    else:
        return self.from_storage(storage)
def bfloat16(self):
    """Return a copy of this tensor with values cast to ``torch.bfloat16``."""
    target_dtype = torch.bfloat16
    return self.type(target_dtype)
...
@@ -454,41 +459,6 @@ SparseTensor.matmul = matmul
...
@@ -454,41 +459,6 @@ SparseTensor.matmul = matmul
# SparseTensor.add = add
# SparseTensor.add = add
# SparseTensor.add_nnz = add_nnz
# SparseTensor.add_nnz = add_nnz
# def remove_diag(self):
# raise NotImplementedError
# def set_diag(self, value):
# raise NotImplementedError
# def __reduce(self, dim, reduce, only_nnz):
# raise NotImplementedError
# def sum(self, dim):
# return self.__reduce(dim, reduce='add', only_nnz=True)
# def prod(self, dim):
# return self.__reduce(dim, reduce='mul', only_nnz=True)
# def min(self, dim, only_nnz=False):
# return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
# def max(self, dim, only_nnz=False):
# return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
# def mean(self, dim, only_nnz=False):
# return self.__reduce(dim, reduce='mean', only_nnz=only_nnz)
# def matmul(self, mat, reduce='add'):
# assert self.numel() == self.nnz() # Disallow multi-dimensional value
# if torch.is_tensor(mat):
# raise NotImplementedError
# elif isinstance(mat, self.__class__):
# assert reduce == 'add'
# assert mat.numel() == mat.nnz() # Disallow multi-dimensional value
# raise NotImplementedError
# raise ValueError('Argument needs to be of type `torch.tensor` or '
# 'type `torch_sparse.SparseTensor`.')
# def __add__(self, other):
# def __add__(self, other):
# return self.add(other)
# return self.add(other)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment