Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
af7786d8
Commit
af7786d8
authored
Jan 13, 2018
by
rusty1s
Browse files
added python impl
parent
a2f2986a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
4 deletions
+62
-4
torch_cluster/__init__.py
torch_cluster/__init__.py
+3
-1
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+47
-0
torch_cluster/functions/utils.py
torch_cluster/functions/utils.py
+8
-0
torch_cluster/src/generic/cpu.c
torch_cluster/src/generic/cpu.c
+4
-3
No files found.
torch_cluster/__init__.py
View file @
af7786d8
from
.functions.grid
import
grid_cluster
__version__
=
'0.1.0'
__all__
=
[
'__version__'
]
__all__
=
[
'grid_cluster'
,
'__version__'
]
torch_cluster/functions/grid.py
0 → 100644
View file @
af7786d8
import
torch
from
.utils
import
get_func
def
grid_cluster
(
position
,
size
,
batch
=
None
):
# TODO: Check types and sizes
print
(
batch
.
type
())
print
(
position
.
type
())
if
batch
is
not
None
:
batch
=
batch
.
type_as
(
position
)
position
=
torch
.
cat
([
position
,
batch
],
dim
=
position
.
dim
()
-
1
)
size
=
torch
.
cat
([
size
,
size
.
new
(
1
).
fill_
(
1
)],
dim
=
0
)
print
(
position
)
# TODO: BATCH
# print(position[0])
# print(position[1])
dim
=
position
.
dim
()
# Allow one-dimensional positions.
if
dim
==
1
:
position
=
position
.
unsqueeze
(
1
)
dim
+=
1
# Translate to minimal positive positions.
min
=
position
.
min
(
dim
=
dim
-
2
,
keepdim
=
True
)[
0
]
position
=
position
-
min
# Compute cluster count for each dimension.
max
=
position
.
max
(
dim
=
0
)[
0
]
while
max
.
dim
()
>
1
:
max
=
max
.
max
(
dim
=
0
)[
0
]
c_max
=
torch
.
ceil
(
max
/
size
.
type_as
(
max
)).
long
()
C
=
c_max
.
prod
()
# Generate cluster tensor.
s
=
list
(
position
.
size
())
s
[
-
1
]
=
1
cluster
=
c_max
.
new
(
torch
.
Size
(
s
))
# Fill cluster tensor and reshape.
func
=
get_func
(
'grid'
,
position
)
func
(
C
,
cluster
,
position
,
size
,
c_max
)
cluster
=
cluster
.
squeeze
(
dim
=
dim
-
1
)
return
cluster
torch_cluster/functions/utils.py
0 → 100644
View file @
af7786d8
from
.._ext
import
ffi
def
get_func
(
name
,
tensor
):
typename
=
type
(
tensor
).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
tensor
.
is_cuda
else
''
func
=
getattr
(
ffi
,
'cluster_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
return
func
torch_cluster/src/generic/cpu.c
View file @
af7786d8
...
...
@@ -5,9 +5,10 @@
void
cluster_
(
grid
)(
int
C
,
THLongTensor
*
output
,
THTensor
*
position
,
THTensor
*
size
,
THLongTensor
*
count
)
{
real
*
size_data
=
size
->
storage
->
data
+
size
->
storageOffset
;
int64_t
*
count_data
=
count
->
storage
->
data
+
count
->
storageOffset
;
int64_t
d
,
i
,
c
,
tmp
;
d
=
THTensor_
(
size
)(
position
,
1
);
TH_TENSOR_DIM_APPLY2
(
int64_t
,
output
,
real
,
position
,
1
,
int64_t
D
,
d
,
i
,
c
,
tmp
;
D
=
THTensor_
(
nDimension
)(
position
);
d
=
THTensor_
(
size
)(
position
,
D
-
1
);
TH_TENSOR_DIM_APPLY2
(
int64_t
,
output
,
real
,
position
,
D
-
1
,
tmp
=
C
;
c
=
0
;
for
(
i
=
0
;
i
<
d
;
i
++
)
{
tmp
=
tmp
/
*
(
count_data
+
i
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment