Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
6a2e1a08
Commit
6a2e1a08
authored
Feb 20, 2018
by
rusty1s
Browse files
rename
parent
fb3d8a81
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
7 deletions
+7
-7
test/test_grid.py
test/test_grid.py
+2
-2
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+5
-5
No files found.
test/test_grid.py
View file @
6a2e1a08
...
@@ -14,7 +14,7 @@ def test_grid_cluster_cpu(tensor):
...
@@ -14,7 +14,7 @@ def test_grid_cluster_cpu(tensor):
assert
output
.
tolist
()
==
expected
.
tolist
()
assert
output
.
tolist
()
==
expected
.
tolist
()
expected
=
torch
.
LongTensor
([
0
,
1
])
expected
=
torch
.
LongTensor
([
0
,
1
])
output
,
_
=
grid_cluster
(
position
,
size
,
o
ffset
=
0
)
output
,
_
=
grid_cluster
(
position
,
size
,
o
rigin
=
0
)
assert
output
.
tolist
()
==
expected
.
tolist
()
assert
output
.
tolist
()
==
expected
.
tolist
()
position
=
Tensor
(
tensor
,
[
0
,
17
,
2
,
8
,
3
])
position
=
Tensor
(
tensor
,
[
0
,
17
,
2
,
8
,
3
])
...
@@ -63,7 +63,7 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
...
@@ -63,7 +63,7 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
expected
=
torch
.
LongTensor
([
0
,
1
])
expected
=
torch
.
LongTensor
([
0
,
1
])
output
,
_
=
grid_cluster
(
position
,
size
,
o
ffset
=
0
)
output
,
_
=
grid_cluster
(
position
,
size
,
o
rigin
=
0
)
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
position
=
Tensor
(
tensor
,
[
0
,
17
,
2
,
8
,
3
]).
cuda
()
position
=
Tensor
(
tensor
,
[
0
,
17
,
2
,
8
,
3
]).
cuda
()
...
...
torch_cluster/functions/grid.py
View file @
6a2e1a08
...
@@ -3,7 +3,7 @@ import torch
...
@@ -3,7 +3,7 @@ import torch
from
.utils
import
get_func
,
consecutive
from
.utils
import
get_func
,
consecutive
def
grid_cluster
(
position
,
size
,
batch
=
None
,
o
ffset
=
None
,
fake_nodes
=
False
):
def
grid_cluster
(
position
,
size
,
batch
=
None
,
o
rigin
=
None
,
fake_nodes
=
False
):
# Allow one-dimensional positions.
# Allow one-dimensional positions.
if
position
.
dim
()
==
1
:
if
position
.
dim
()
==
1
:
position
=
position
.
unsqueeze
(
-
1
)
position
=
position
.
unsqueeze
(
-
1
)
...
@@ -21,14 +21,14 @@ def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False):
...
@@ -21,14 +21,14 @@ def grid_cluster(position, size, batch=None, offset=None, fake_nodes=False):
position
=
torch
.
cat
([
batch
,
position
],
dim
=-
1
)
position
=
torch
.
cat
([
batch
,
position
],
dim
=-
1
)
size
=
torch
.
cat
([
size
.
new
(
1
).
fill_
(
1
),
size
],
dim
=-
1
)
size
=
torch
.
cat
([
size
.
new
(
1
).
fill_
(
1
),
size
],
dim
=-
1
)
# Translate to minimal positive positions if no o
ffset i
s passed.
# Translate to minimal positive positions if no o
rigin wa
s passed.
if
o
ffset
is
None
:
if
o
rigin
is
None
:
min
=
position
.
min
(
dim
=-
2
,
keepdim
=
True
)[
0
]
min
=
position
.
min
(
dim
=-
2
,
keepdim
=
True
)[
0
]
position
=
position
-
min
position
=
position
-
min
else
:
else
:
position
=
position
+
o
ffset
position
=
position
+
o
rigin
assert
position
.
min
()
>=
0
,
(
assert
position
.
min
()
>=
0
,
(
'Passed o
ffset
resulting in unallowed negative positions'
)
'Passed o
rigin
resulting in unallowed negative positions'
)
# Compute cluster count for each dimension.
# Compute cluster count for each dimension.
max
=
position
.
max
(
dim
=
0
)[
0
]
max
=
position
.
max
(
dim
=
0
)[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment