Commit 642d548f authored by rusty1s

new try

parent fbd14a95
@@ -12,7 +12,8 @@ def grid(pos, size, start=None, end=None):
 def graclus(row, col, num_nodes):
-    return cluster_cpu.graclus(row, col, num_nodes)
+    lib = cluster_cuda if row.is_cuda else cluster_cpu
+    return lib.graclus(row, col, num_nodes)

 device = torch.device('cuda')
@@ -23,10 +24,11 @@ print('size', size.tolist())
 cluster = grid(pos, size)
 print('result', cluster.tolist(), cluster.dtype, cluster.device)
-row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
-col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
-print(row)
-print(col)
 print('-----------------')
+row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3], device=device)
+col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2], device=device)
+print('row', row.tolist())
+print('col', col.tolist())
 cluster = graclus(row, col, 4)
-print(cluster)
+print('result', cluster.tolist(), cluster.dtype, cluster.device)
cpu/cluster.cpp
@@ -7,9 +7,9 @@
 at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
   std::tie(row, col) = remove_self_loops(row, col);
   std::tie(row, col) = randperm(row, col, num_nodes);
-  auto deg = degree(row, num_nodes, row.type().scalarType());
   auto cluster = at::full(row.type(), {num_nodes}, -1);
+  auto deg = degree(row, num_nodes, row.type().scalarType());

   auto *row_data = row.data<int64_t>();
   auto *col_data = col.data<int64_t>();
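
For orientation, the greedy matching this CPU kernel sets up can be sketched in Python. This is a simplified reading of graclus-style clustering, not the commit's actual kernel code, and the helper name is hypothetical:

import torch

def graclus_sketch(row, col, num_nodes):
    # Mirrors at::full(row.type(), {num_nodes}, -1): every node starts
    # unassigned.
    cluster = torch.full((num_nodes,), -1, dtype=torch.long)
    # Walk the (already randomly permuted) edge list and greedily merge
    # pairs of still-unassigned nodes.
    for r, c in zip(row.tolist(), col.tolist()):
        if cluster[r] == -1 and cluster[c] == -1:
            cluster[r] = cluster[c] = r
    # Nodes left unmatched become singleton clusters.
    unmatched = cluster == -1
    cluster[unmatched] = torch.arange(num_nodes)[unmatched]
    return cluster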
cuda/cluster.cpp
-#include <torch/torch.h>
-
-at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
-                     at::Tensor end);
-
-#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
-
-at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
-                at::Tensor end) {
-  CHECK_CUDA(pos);
-  CHECK_CUDA(size);
-  CHECK_CUDA(start);
-  CHECK_CUDA(end);
-  return grid_cuda(pos, size, start, end);
-}
+#include "graclus.cpp"
+#include "grid.cpp"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("graclus", &graclus, "Graclus (CUDA)");
   m.def("grid", &grid, "Grid (CUDA)");
 }
cuda/graclus.cpp (new file)
#include <torch/torch.h>

#include "../include/degree.cpp"
#include "../include/loop.cpp"

at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
  CHECK_CUDA(row);
  CHECK_CUDA(col);

  std::tie(row, col) = remove_self_loops(row, col);
  auto deg = degree(row, num_nodes, row.type().scalarType());
  return deg;
}
cuda/grid.cpp (new file)
#include <torch/torch.h>

at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end);

at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  CHECK_CUDA(start);
  CHECK_CUDA(end);
  return grid_cuda(pos, size, start, end);
}
#include "degree.h"
#ifndef DEGREE_INC
#define DEGREE_INC
#include <torch/torch.h>
......@@ -8,3 +9,5 @@ inline at::Tensor degree(at::Tensor index, int num_nodes,
auto one = at::full(zero.type(), {index.size(0)}, 1);
return zero.scatter_add_(0, index, one);
}
#endif // DEGREE_INC
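
In Python terms, the degree helper above is a scatter-add of ones, one slot per node. A minimal sketch (function name assumed):

import torch

def degree_sketch(index, num_nodes, dtype=torch.long):
    # zero.scatter_add_(0, index, one): add 1 at position index[i] for
    # every entry of index.
    zero = torch.zeros(num_nodes, dtype=dtype)
    one = torch.ones(index.size(0), dtype=dtype)
    return zero.scatter_add_(0, index, one)

degree_sketch(torch.tensor([0, 0, 1, 2]), 4)  # tensor([2, 1, 1, 0])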
include/degree.h (deleted)
#ifndef DEGREE_INC
#define DEGREE_INC

#include <torch/torch.h>

inline at::Tensor degree(at::Tensor index, int num_nodes,
                         at::ScalarType scalar_type);

#endif // DEGREE_INC
#include "loop.h"
#ifndef LOOP_INC
#define LOOP_INC
#include <torch/torch.h>
......@@ -7,3 +8,5 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
auto mask = row != col;
return {row.masked_select(mask), col.masked_select(mask)};
}
#endif // LOOP_INC
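
remove_self_loops above is plain boolean masking; a one-to-one Python sketch (name assumed):

import torch

def remove_self_loops_sketch(row, col):
    # Keep only edges whose endpoints differ, as the two masked_select
    # calls do.
    mask = row != col
    return row[mask], col[mask]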
include/loop.h (deleted)
#ifndef LOOP_INC
#define LOOP_INC

#include <torch/torch.h>

inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col);

#endif // LOOP_INC
#include "perm.h"
#ifndef PERM_INC
#define PERM_INC
#include <torch/torch.h>
......@@ -22,3 +23,5 @@ randperm(at::Tensor row, at::Tensor col, int num_nodes) {
return {row, col};
}
#endif // PERM_INC
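
The body of randperm is mostly elided in this hunk. One plausible reading, sketched in Python under that assumption, is that it shuffles the edge list so the greedy matching visits edges in random order (the role of num_nodes is not visible here and is ignored in the sketch):

import torch

def randperm_edges_sketch(row, col):
    # One shared permutation keeps each (row[i], col[i]) pair intact
    # while randomizing the visiting order.
    perm = torch.randperm(row.size(0))
    return row[perm], col[perm]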
include/perm.h (deleted)
#ifndef PERM_INC
#define PERM_INC

#include <torch/torch.h>

inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes);

#endif // PERM_INC
setup.py
 import torch
+import glob
 from setuptools import setup
 import torch.cuda
 from torch.utils.cpp_extension import CppExtension, CUDAExtension

-ext_modules = [CppExtension(name='cluster_cpu', sources=['cpu/cluster.cpp'])]
+ext_modules = [CppExtension('cluster_cpu', ['cpu/cluster.cpp'])]

 if torch.cuda.is_available():
     ext_modules += [
-        CUDAExtension(
-            name='cluster_cuda',
-            sources=['cuda/cluster.cpp', 'cuda/cluster_kernel.cu'])
+        CUDAExtension('cluster_cuda',
+                      ['cuda/cluster.cpp'] + glob.glob('cuda/*.cu'))
     ]

 setup(
 ...
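
With this setup.py, a quick smoke test of the CPU extension might look as follows. A hedged sketch: the build command is the standard setuptools one, and the tensors are the ones from the test script above.

# $ python setup.py install
import torch
import cluster_cpu

row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
print(cluster_cpu.graclus(row, col, 4))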