Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
f0a4d21d
Commit
f0a4d21d
authored
Apr 27, 2018
by
rusty1s
Browse files
none arguments
parent
fa36a835
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
9 deletions
+7
-9
aten/cpu/cluster.cpp
aten/cpu/cluster.cpp
+2
-6
aten/cpu/cluster.py
aten/cpu/cluster.py
+5
-3
No files found.
aten/cpu/cluster.cpp
View file @
f0a4d21d
...
...
@@ -7,9 +7,6 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, at::Tensor weight) {
at
::
Tensor
grid
(
at
::
Tensor
pos
,
at
::
Tensor
size
,
at
::
Tensor
start
,
at
::
Tensor
end
)
{
if
(
!
start
.
defined
())
start
=
std
::
get
<
0
>
(
pos
.
min
(
1
));
if
(
!
end
.
defined
())
end
=
std
::
get
<
0
>
(
pos
.
max
(
1
));
size
=
size
.
toType
(
pos
.
type
());
start
=
start
.
toType
(
pos
.
type
());
end
=
end
.
toType
(
pos
.
type
());
...
...
@@ -30,7 +27,6 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"graclus"
,
&
graclus
,
"Graclus (CPU)"
,
py
::
arg
(
"row"
),
py
::
arg
(
"col"
),
py
::
arg
(
"weight"
));
m
.
def
(
"grid"
,
&
grid
,
"Grid (CPU)"
,
py
::
arg
(
"pos"
),
py
::
arg
(
"size"
),
py
::
arg
(
"start"
),
py
::
arg
(
"end"
));
m
.
def
(
"graclus"
,
&
graclus
,
"Graclus (CPU)"
);
m
.
def
(
"grid"
,
&
grid
,
"Grid (CPU)"
);
}
aten/cpu/cluster.py
View file @
f0a4d21d
...
...
@@ -3,15 +3,17 @@ import torch
import
cluster_cpu
def
grid_cluster
(
pos
,
size
,
start
,
end
):
def
grid_cluster
(
pos
,
size
,
start
=
None
,
end
=
None
):
start
=
pos
.
t
().
min
(
dim
=
1
)[
0
]
if
start
is
None
else
start
end
=
pos
.
t
().
max
(
dim
=
1
)[
0
]
if
end
is
None
else
end
return
cluster_cpu
.
grid
(
pos
,
size
,
start
,
end
)
pos
=
torch
.
tensor
([[
1
,
1
],
[
3
,
3
],
[
5
,
5
],
[
7
,
7
]]
,
dtype
=
torch
.
uint8
)
pos
=
torch
.
tensor
([[
1
,
1
],
[
3
,
3
],
[
5
,
5
],
[
7
,
7
]])
size
=
torch
.
tensor
([
2
,
2
])
start
=
torch
.
tensor
([
0
,
0
])
end
=
torch
.
tensor
([
7
,
7
])
print
(
'pos'
,
pos
.
tolist
())
print
(
'size'
,
size
.
tolist
())
cluster
=
grid_cluster
(
pos
,
size
,
start
,
end
)
cluster
=
grid_cluster
(
pos
,
size
)
print
(
'result'
,
cluster
.
tolist
(),
cluster
.
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment