Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
43fde05a
Commit
43fde05a
authored
Jan 14, 2018
by
rusty1s
Browse files
added assert checks
parent
61084dfe
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
14 deletions
+17
-14
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+17
-14
No files found.
torch_cluster/functions/grid.py
View file @
43fde05a
...
...
@@ -4,23 +4,25 @@ from .utils import get_func
def
grid_cluster
(
position
,
size
,
batch
=
None
):
# TODO: Check types and sizes
if
batch
is
not
None
:
batch
=
batch
.
type_as
(
position
)
size
=
torch
.
cat
([
size
.
new
(
1
).
fill_
(
1
),
size
],
dim
=
0
)
dim
=
position
.
dim
()
position
=
torch
.
cat
([
batch
.
unsqueeze
(
dim
-
1
),
position
],
dim
=
dim
-
1
)
# Allow one-dimensional positions.
if
position
.
dim
()
==
1
:
position
=
position
.
unsqueeze
(
-
1
)
dim
=
position
.
dim
()
assert
size
.
dim
()
==
1
,
'Size tensor must be one-dimensional'
assert
position
.
size
(
-
1
)
==
size
.
size
(
-
1
),
(
'Last dimension of position tensor must have same size as size tensor'
)
# Allow one-dimensional positions.
if
dim
==
1
:
position
=
position
.
unsqueeze
(
1
)
dim
+=
1
# If given, append batch to position tensor.
if
batch
is
not
None
:
batch
=
batch
.
unsqueeze
(
-
1
).
type_as
(
position
)
assert
position
.
size
()[:
-
1
]
==
batch
.
size
()[:
-
1
],
(
'Position tensor must have same size as batch tensor apart from '
'the last dimension'
)
position
=
torch
.
cat
([
batch
,
position
],
dim
=-
1
)
size
=
torch
.
cat
([
size
.
new
(
1
).
fill_
(
1
),
size
],
dim
=-
1
)
# Translate to minimal positive positions.
min
=
position
.
min
(
dim
=
dim
-
2
,
keepdim
=
True
)[
0
]
min
=
position
.
min
(
dim
=
-
2
,
keepdim
=
True
)[
0
]
position
=
position
-
min
# Compute cluster count for each dimension.
...
...
@@ -37,8 +39,9 @@ def grid_cluster(position, size, batch=None):
cluster
=
c_max
.
new
(
torch
.
Size
(
s
))
# Fill cluster tensor and reshape.
size
=
size
.
type_as
(
position
)
func
=
get_func
(
'grid'
,
position
)
func
(
C
,
cluster
,
position
,
size
,
c_max
)
cluster
=
cluster
.
squeeze
(
dim
=
dim
-
1
)
cluster
=
cluster
.
squeeze
(
dim
=
-
1
)
return
cluster
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment