Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
7a4a31ae
"docker/diffusers-pytorch-minimum-cuda/Dockerfile" did not exist on "bcb476797ccb7523f3e114f7440b4c8d9bb7154b"
Commit
7a4a31ae
authored
Dec 19, 2019
by
rusty1s
Browse files
typo
parent
2554bf09
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
3 deletions
+8
-3
torch_sparse/index_select.py
torch_sparse/index_select.py
+3
-3
torch_sparse/tensor.py
torch_sparse/tensor.py
+5
-0
No files found.
torch_sparse/index_select.py
View file @
7a4a31ae
...
...
@@ -4,7 +4,7 @@ from torch_sparse.storage import get_layout
import
torch_sparse.arange_interleave_cpu
as
arange_interleave_cpu
def
__
arange_interleave
__
(
start
,
repeat
):
def
arange_interleave
(
start
,
repeat
):
assert
start
.
device
==
repeat
.
device
assert
repeat
.
dtype
==
torch
.
long
assert
start
.
dim
()
==
1
...
...
@@ -29,7 +29,7 @@ def index_select(src, dim, idx):
rowcount
=
rowcount
[
idx
]
tmp
=
torch
.
arange
(
rowcount
.
size
(
0
),
device
=
rowcount
.
device
)
row
=
tmp
.
repeat_interleave
(
rowcount
)
perm
=
__
arange_interleave
__
(
rowptr
[
idx
],
rowcount
)
perm
=
arange_interleave
(
rowptr
[
idx
],
rowcount
)
col
=
col
[
perm
]
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
...
...
@@ -48,7 +48,7 @@ def index_select(src, dim, idx):
colcount
=
colcount
[
idx
]
tmp
=
torch
.
arange
(
colcount
.
size
(
0
),
device
=
row
.
device
)
col
=
tmp
.
repeat_interleave
(
colcount
)
perm
=
__
arange_interleave
__
(
colptr
[
idx
],
colcount
)
perm
=
arange_interleave
(
colptr
[
idx
],
colcount
)
row
=
row
[
perm
]
csc2csr
=
(
colcount
.
size
(
0
)
*
row
+
col
).
argsort
()
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)[:,
csc2csr
]
...
...
torch_sparse/tensor.py
View file @
7a4a31ae
...
...
@@ -326,6 +326,9 @@ SparseTensor.masked_select_nnz = masked_select_nnz
# # Filter list into edge and sparse slice
# raise NotImplementedError
# def remove_diag(self):
# raise NotImplementedError
# def set_diag(self, value):
# raise NotImplementedError
...
...
@@ -358,6 +361,8 @@ SparseTensor.masked_select_nnz = masked_select_nnz
# raise ValueError('Argument needs to be of type `torch.tensor` or '
# 'type `torch_sparse.SparseTensor`.')
# def add_nnz(self):
# def add(self, other, layout=None):
# if __is_scalar__(other):
# if self.has_value:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment