Commit f568f19a authored by rusty1s's avatar rusty1s
Browse files

to int

parent e1125f13
...@@ -7,7 +7,7 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, ...@@ -7,7 +7,7 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
} }
inline std::tuple<at::Tensor, at::Tensor> inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) { randperm(at::Tensor row, at::Tensor col, int num_nodes) {
// Randomly reorder row and column indices. // Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0)); auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm); row = row.index_select(0, perm);
...@@ -27,14 +27,14 @@ randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) { ...@@ -27,14 +27,14 @@ randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
return {row, col}; return {row, col};
} }
inline at::Tensor degree(at::Tensor index, int64_t num_nodes, inline at::Tensor degree(at::Tensor index, int num_nodes,
at::ScalarType scalar_type) { at::ScalarType scalar_type) {
auto zero = at::full(torch::CPU(scalar_type), {num_nodes}, 0); auto zero = at::full(torch::CPU(scalar_type), {num_nodes}, 0);
auto one = at::full(zero.type(), {index.size(0)}, 1); auto one = at::full(zero.type(), {index.size(0)}, 1);
return zero.scatter_add_(0, index, one); return zero.scatter_add_(0, index, one);
} }
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
std::tie(row, col) = remove_self_loops(row, col); std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = randperm(row, col, num_nodes); std::tie(row, col) = randperm(row, col, num_nodes);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment