Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
eb8c2ec0
Commit
eb8c2ec0
authored
Jun 23, 2020
by
rusty1s
Browse files
cleanup
parent
45a4d985
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
7 deletions
+9
-7
torch_sparse/metis.py
torch_sparse/metis.py
+9
-7
No files found.
torch_sparse/metis.py
View file @
eb8c2ec0
...
@@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
...
@@ -22,6 +22,13 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
,
weighted
=
False
weighted
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
num_parts
>=
1
if
num_parts
==
1
:
partptr
=
torch
.
tensor
([
0
,
src
.
size
(
0
)],
device
=
src
.
device
())
perm
=
torch
.
arange
(
src
.
size
(
0
),
device
=
src
.
device
())
return
src
,
partptr
,
perm
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
,
value
=
src
.
csr
()
rowptr
,
col
=
rowptr
.
cpu
(),
col
.
cpu
()
rowptr
,
col
=
rowptr
.
cpu
(),
col
.
cpu
()
...
@@ -33,13 +40,8 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
...
@@ -33,13 +40,8 @@ def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
else
:
else
:
value
=
None
value
=
None
if
num_parts
>
1
:
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
num_parts
,
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
value
,
recursive
)
num_parts
,
recursive
)
elif
num_parts
==
1
:
cluster
=
torch
.
zeros
((
src
.
size
(
0
)),
dtype
=
torch
.
long
)
else
:
raise
ValueError
cluster
=
cluster
.
to
(
src
.
device
())
cluster
=
cluster
.
to
(
src
.
device
())
cluster
,
perm
=
cluster
.
sort
()
cluster
,
perm
=
cluster
.
sort
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment