Commit abc17fd1 authored by rusty1s's avatar rusty1s
Browse files

linting

parent db0ffb45
#include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, at::Tensor col) {
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
at::Tensor col) {
auto mask = row != col;
row = row.masked_select(mask);
col = col.masked_select(mask);
return {row, col};
}
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
// Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm);
......@@ -29,13 +29,11 @@ inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor co
return {row, col};
}
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes});
return zero.scatter_add_(0, index, at::ones_like(index));
}
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = randperm(row, col, num_nodes);
......@@ -68,8 +66,8 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
return cluster;
}
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor end) {
size = size.toType(pos.type());
start = start.toType(pos.type());
end = end.toType(pos.type());
......@@ -88,7 +86,6 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
return cluster;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("graclus", &graclus, "Graclus (CPU)");
m.def("grid", &grid, "Grid (CPU)");
......
import os.path as osp
import shutil
import subprocess
import torch
from torch.utils.ffi import create_extension
if osp.exists('build'):
shutil.rmtree('build')
files = ['Graclus', 'Grid']
headers = ['aten/TH/TH{}.h'.format(f) for f in files]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment