Commit 672358fd authored by rusty1s's avatar rusty1s
Browse files

outsource

parent c8338385
#ifndef DEGREE_CPP
#define DEGREE_CPP
#include <torch/torch.h>
inline at::Tensor degree(at::Tensor index, int num_nodes,
at::ScalarType scalar_type) {
auto zero = at::full(torch::CPU(scalar_type), {num_nodes}, 0);
auto one = at::full(zero.type(), {index.size(0)}, 1);
return zero.scatter_add_(0, index, one);
}
#endif // DEGREE_CPP
#include <torch/torch.h> #include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, #include "degree.cpp"
at::Tensor col) { #include "loop.cpp"
auto mask = row != col; #include "perm.cpp"
return {row.masked_select(mask), col.masked_select(mask)};
}
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes) {
// Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm);
col = col.index_select(0, perm);
// Randomly swap row values.
auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes);
row = node_rid.index_select(0, row);
// Sort row and column indices row-wise.
std::tie(row, perm) = row.sort();
col = col.index_select(0, perm);
// Revert row value swaps.
row = std::get<1>(node_rid.sort()).index_select(0, row);
return {row, col};
}
inline at::Tensor degree(at::Tensor index, int num_nodes,
at::ScalarType scalar_type) {
auto zero = at::full(torch::CPU(scalar_type), {num_nodes}, 0);
auto one = at::full(zero.type(), {index.size(0)}, 1);
return zero.scatter_add_(0, index, one);
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) { at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
std::tie(row, col) = remove_self_loops(row, col); std::tie(row, col) = remove_self_loops(row, col);
......
#ifndef LOOP_CPP
#define LOOP_CPP
#include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
at::Tensor col) {
auto mask = row != col;
return {row.masked_select(mask), col.masked_select(mask)};
}
#endif // LOOP_CPP
#ifndef PERM_CPP
#define PERM_CPP
#include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes) {
// Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm);
col = col.index_select(0, perm);
// Randomly swap row values.
auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes);
row = node_rid.index_select(0, row);
// Sort row and column indices row-wise.
std::tie(row, perm) = row.sort();
col = col.index_select(0, perm);
// Revert row value swaps.
row = std::get<1>(node_rid.sort()).index_select(0, row);
return {row, col};
}
#endif // PERM_CPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment