Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
a315a06d
Commit
a315a06d
authored
Apr 26, 2018
by
rusty1s
Browse files
first aten try
parent
de54ccec
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
0 deletions
+61
-0
aten/cpu/cluster.cpp
aten/cpu/cluster.cpp
+36
-0
aten/cpu/cluster.py
aten/cpu/cluster.py
+17
-0
aten/cpu/setup.py
aten/cpu/setup.py
+8
-0
No files found.
aten/cpu/cluster.cpp
0 → 100644
View file @
a315a06d
#include <torch/torch.h>
at
::
Tensor
graclus
(
at
::
Tensor
row
,
at
::
Tensor
col
,
at
::
Tensor
weight
)
{
return
row
;
}
at
::
Tensor
grid
(
at
::
Tensor
pos
,
at
::
Tensor
size
,
at
::
Tensor
start
,
at
::
Tensor
end
)
{
if
(
!
start
.
defined
())
start
=
std
::
get
<
0
>
(
pos
.
min
(
1
));
if
(
!
end
.
defined
())
end
=
std
::
get
<
0
>
(
pos
.
max
(
1
));
size
=
size
.
toType
(
pos
.
type
());
start
=
start
.
toType
(
pos
.
type
());
end
=
end
.
toType
(
pos
.
type
());
pos
=
pos
-
start
.
view
({
1
,
-
1
});
auto
num_voxels
=
((
end
-
start
)
/
size
).
toType
(
at
::
kLong
);
num_voxels
=
(
num_voxels
+
1
).
cumsum
(
0
);
num_voxels
=
num_voxels
-
num_voxels
[
0
];
num_voxels
[
0
]
=
1
;
auto
cluster
=
pos
/
size
.
view
({
1
,
-
1
});
cluster
=
cluster
.
toType
(
at
::
kLong
);
cluster
*=
num_voxels
.
view
({
1
,
-
1
});
cluster
=
cluster
.
sum
(
1
);
return
cluster
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"graclus"
,
&
graclus
,
"Graclus (CPU)"
,
py
::
arg
(
"row"
),
py
::
arg
(
"col"
),
py
::
arg
(
"weight"
));
m
.
def
(
"grid"
,
&
grid
,
"Grid (CPU)"
,
py
::
arg
(
"pos"
),
py
::
arg
(
"size"
),
py
::
arg
(
"start"
),
py
::
arg
(
"end"
));
}
aten/cpu/cluster.py
0 → 100644
View file @
a315a06d
import
torch
import
cluster_cpu
def
grid_cluster
(
pos
,
size
,
start
,
end
):
return
cluster_cpu
.
grid
(
pos
,
size
,
start
,
end
)
pos
=
torch
.
tensor
([[
1
,
1
],
[
3
,
3
],
[
5
,
5
],
[
7
,
7
]],
dtype
=
torch
.
uint8
)
size
=
torch
.
tensor
([
2
,
2
])
start
=
torch
.
tensor
([
0
,
0
])
end
=
torch
.
tensor
([
7
,
7
])
print
(
'pos'
,
pos
.
tolist
())
print
(
'size'
,
size
.
tolist
())
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
print
(
'result'
,
cluster
.
tolist
(),
cluster
.
dtype
)
aten/cpu/setup.py
0 → 100644
View file @
a315a06d
from
setuptools
import
setup
from
torch.utils.cpp_extension
import
BuildExtension
,
CppExtension
setup
(
name
=
'cluster'
,
ext_modules
=
[
CppExtension
(
'cluster_cpu'
,
[
'cluster.cpp'
])],
cmdclass
=
{
'build_ext'
:
BuildExtension
},
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment