"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e2ee171a4ace19cc5be9c6d65505417303a227e8"
Commit abc17fd1 authored by rusty1s's avatar rusty1s
Browse files

linting

parent db0ffb45
#include <torch/torch.h> #include <torch/torch.h>
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, at::Tensor col) { at::Tensor col) {
auto mask = row != col; auto mask = row != col;
row = row.masked_select(mask); row = row.masked_select(mask);
col = col.masked_select(mask); col = col.masked_select(mask);
return {row, col}; return {row, col};
} }
inline std::tuple<at::Tensor, at::Tensor>
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) { randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
// Randomly reorder row and column indices. // Randomly reorder row and column indices.
auto perm = at::randperm(torch::CPU(at::kLong), row.size(0)); auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
row = row.index_select(0, perm); row = row.index_select(0, perm);
...@@ -29,13 +29,11 @@ inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor co ...@@ -29,13 +29,11 @@ inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor co
return {row, col}; return {row, col};
} }
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) { inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes}); auto zero = at::zeros(torch::CPU(at::kLong), {num_nodes});
return zero.scatter_add_(0, index, at::ones_like(index)); return zero.scatter_add_(0, index, at::ones_like(index));
} }
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
std::tie(row, col) = remove_self_loops(row, col); std::tie(row, col) = remove_self_loops(row, col);
std::tie(row, col) = randperm(row, col, num_nodes); std::tie(row, col) = randperm(row, col, num_nodes);
...@@ -68,8 +66,8 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { ...@@ -68,8 +66,8 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
return cluster; return cluster;
} }
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) { at::Tensor end) {
size = size.toType(pos.type()); size = size.toType(pos.type());
start = start.toType(pos.type()); start = start.toType(pos.type());
end = end.toType(pos.type()); end = end.toType(pos.type());
...@@ -88,7 +86,6 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en ...@@ -88,7 +86,6 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
return cluster; return cluster;
} }
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("graclus", &graclus, "Graclus (CPU)"); m.def("graclus", &graclus, "Graclus (CPU)");
m.def("grid", &grid, "Grid (CPU)"); m.def("grid", &grid, "Grid (CPU)");
......
import os.path as osp import os.path as osp
import shutil
import subprocess import subprocess
import torch import torch
from torch.utils.ffi import create_extension from torch.utils.ffi import create_extension
if osp.exists('build'):
shutil.rmtree('build')
files = ['Graclus', 'Grid'] files = ['Graclus', 'Grid']
headers = ['aten/TH/TH{}.h'.format(f) for f in files] headers = ['aten/TH/TH{}.h'.format(f) for f in files]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment