Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
df2ed804
Commit
df2ed804
authored
Apr 27, 2018
by
rusty1s
Browse files
graclus done
parent
dcd88f5a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
19 deletions
+40
-19
aten/cpu/cluster.cpp
aten/cpu/cluster.cpp
+36
-17
aten/cpu/cluster.py
aten/cpu/cluster.py
+4
-2
No files found.
aten/cpu/cluster.cpp
View file @
df2ed804
#include <torch/torch.h>
inline
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
randperm
(
at
::
Tensor
row
,
at
::
Tensor
col
)
{
/* at::Tensor perm; */
/* std::tie(row, perm) = row.sort(); */
/* col = col.index_select(0, perm); */
/* TODO: randperm */
/* TODO: randperm_sort_row */
return
{
row
,
col
};
inline
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
remove_self_loops
(
at
::
Tensor
row
,
at
::
Tensor
col
)
{
auto
mask
=
row
!=
col
;
row
=
row
.
masked_select
(
mask
);
col
=
col
.
masked_select
(
mask
);
return
{
row
,
col
};
}
inline
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
randperm
(
at
::
Tensor
row
,
at
::
Tensor
col
,
int64_t
num_nodes
)
{
// Randomly reorder row and column indices.
auto
perm
=
at
::
randperm
(
torch
::
CPU
(
at
::
kLong
),
row
.
size
(
0
));
row
=
row
.
index_select
(
0
,
perm
);
col
=
col
.
index_select
(
0
,
perm
);
// Randomly swap row values.
auto
node_rid
=
at
::
randperm
(
torch
::
CPU
(
at
::
kLong
),
num_nodes
);
row
=
node_rid
.
index_select
(
0
,
row
);
// Sort row and column indices row-wise.
std
::
tie
(
row
,
perm
)
=
row
.
sort
();
col
=
col
.
index_select
(
0
,
perm
);
// Revert row value swaps.
row
=
std
::
get
<
1
>
(
node_rid
.
sort
()).
index_select
(
0
,
row
);
return
{
row
,
col
};
}
inline
at
::
Tensor
degree
(
at
::
Tensor
index
,
int64_t
num_nodes
)
{
auto
zero
=
at
::
zeros
(
torch
::
CPU
(
at
::
kLong
),
{
num_nodes
});
auto
zero
=
at
::
zeros
(
torch
::
CPU
(
at
::
kLong
),
{
num_nodes
});
return
zero
.
scatter_add_
(
0
,
index
,
at
::
ones_like
(
index
));
}
at
::
Tensor
graclus
(
at
::
Tensor
row
,
at
::
Tensor
col
,
int64_t
num_nodes
)
{
std
::
tie
(
row
,
col
)
=
randperm
(
row
,
col
);
std
::
tie
(
row
,
col
)
=
remove_self_loops
(
row
,
col
);
std
::
tie
(
row
,
col
)
=
randperm
(
row
,
col
,
num_nodes
);
auto
deg
=
degree
(
row
,
num_nodes
);
auto
cluster
=
at
::
empty
(
torch
::
CPU
(
at
::
kLong
),
{
num_nodes
}).
fill_
(
-
1
);
auto
cluster
=
at
::
empty
(
torch
::
CPU
(
at
::
kLong
),
{
num_nodes
}).
fill_
(
-
1
);
auto
*
row_data
=
row
.
data
<
int64_t
>
();
auto
*
col_data
=
col
.
data
<
int64_t
>
();
auto
*
deg_data
=
deg
.
data
<
int64_t
>
();
auto
*
cluster_data
=
cluster
.
data
<
int64_t
>
();
int64_t
n_idx
=
0
,
e_idx
=
0
,
d_idx
,
r
,
c
;
int64_t
e_idx
=
0
,
d_idx
,
r
,
c
;
while
(
e_idx
<
row
.
size
(
0
))
{
r
=
row_data
[
e_idx
];
if
(
cluster_data
[
r
]
<
0
)
{
...
...
@@ -42,8 +62,7 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
}
}
}
e_idx
+=
deg_data
[
n_idx
];
n_idx
++
;
e_idx
+=
deg_data
[
r
];
}
return
cluster
;
...
...
@@ -55,15 +74,15 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
start
=
start
.
toType
(
pos
.
type
());
end
=
end
.
toType
(
pos
.
type
());
pos
=
pos
-
start
.
view
({
1
,
-
1
});
pos
=
pos
-
start
.
view
({
1
,
-
1
});
auto
num_voxels
=
((
end
-
start
)
/
size
).
toType
(
at
::
kLong
);
num_voxels
=
(
num_voxels
+
1
).
cumsum
(
0
);
num_voxels
-=
num_voxels
.
data
<
int64_t
>
()[
0
];
num_voxels
.
data
<
int64_t
>
()[
0
]
=
1
;
auto
cluster
=
pos
/
size
.
view
({
1
,
-
1
});
auto
cluster
=
pos
/
size
.
view
({
1
,
-
1
});
cluster
=
cluster
.
toType
(
at
::
kLong
);
cluster
*=
num_voxels
.
view
({
1
,
-
1
});
cluster
*=
num_voxels
.
view
({
1
,
-
1
});
cluster
=
cluster
.
sum
(
1
);
return
cluster
;
...
...
aten/cpu/cluster.py
View file @
df2ed804
...
...
@@ -25,5 +25,7 @@ def graclus_cluster(row, col, num_nodes):
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
,
1
,
2
,
2
,
2
,
3
,
3
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
,
3
,
0
,
1
,
3
,
1
,
2
])
print
(
row
)
print
(
graclus_cluster
(
row
,
col
,
4
))
print
(
col
)
print
(
'-----------------'
)
cluster
=
graclus_cluster
(
row
,
col
,
4
)
print
(
cluster
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment