#pragma once #include std::tuple remove_self_loops(at::Tensor row, at::Tensor col) { auto mask = row != col; return std::make_tuple(row.masked_select(mask), col.masked_select(mask)); } std::tuple remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) { auto mask = row != col; return std::make_tuple(row.masked_select(mask), col.masked_select(mask), weight.masked_select(mask)); } at::Tensor randperm(int64_t n) { auto out = at::empty(n, torch::CPU(at::kLong)); at::randperm_out(out, n); return out; } std::tuple rand(at::Tensor row, at::Tensor col) { auto perm = randperm(row.size(0)); return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm)); } std::tuple sort_by_row(at::Tensor row, at::Tensor col) { at::Tensor perm; std::tie(row, perm) = row.sort(); return std::make_tuple(row, col.index_select(0, perm)); } std::tuple sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) { at::Tensor perm; std::tie(row, perm) = row.sort(); return std::make_tuple(row, col.index_select(0, perm), weight.index_select(0, perm)); } at::Tensor degree(at::Tensor row, int64_t num_nodes) { auto zero = zeros(num_nodes, row.options()); auto one = ones(row.size(0), row.options()); return zero.scatter_add_(0, row, one); } std::tuple to_csr(at::Tensor row, at::Tensor col, int64_t num_nodes) { std::tie(row, col) = sort_by_row(row, col); row = degree(row, num_nodes).cumsum(0); row = at::cat({zeros(1, row.options()), row}, 0); // Prepend zero. return std::make_tuple(row, col); } std::tuple to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) { std::tie(row, col, weight) = sort_by_row(row, col, weight); row = degree(row, num_nodes).cumsum(0); row = at::cat({zeros(1, row.options()), row}, 0); // Prepend zero. return std::make_tuple(row, col, weight); }