Commit 642d548f authored by rusty1s

new try

parent fbd14a95
```diff
@@ -12,7 +12,8 @@ def grid(pos, size, start=None, end=None):
 def graclus(row, col, num_nodes):
-    return cluster_cpu.graclus(row, col, num_nodes)
+    lib = cluster_cuda if row.is_cuda else cluster_cpu
+    return lib.graclus(row, col, num_nodes)


 device = torch.device('cuda')
@@ -23,10 +24,11 @@ print('size', size.tolist())
 cluster = grid(pos, size)
 print('result', cluster.tolist(), cluster.dtype, cluster.device)

-row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
-col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
-print(row)
-print(col)
 print('-----------------')
+row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3], device=device)
+col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2], device=device)
+print('row', row.tolist())
+print('col', col.tolist())
 cluster = graclus(row, col, 4)
-print(cluster)
+print('result', cluster.tolist(), cluster.dtype, cluster.device)
```
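For reference, a minimal sketch of the device-dispatch idiom the updated helper relies on; `cluster_cpu` and `cluster_cuda` are the extensions built by `setup.py` below, while the `ImportError` guard for CPU-only installs is an addition here, not part of the original:

```python
import torch
import cluster_cpu  # C++ extension, always built

try:
    import cluster_cuda  # CUDA extension, only built when CUDA is available
except ImportError:
    cluster_cuda = None


def graclus(row, col, num_nodes):
    # Route CUDA tensors to the CUDA kernels, everything else to C++;
    # both extensions expose the same graclus(row, col, num_nodes) op.
    lib = cluster_cuda if (row.is_cuda and cluster_cuda is not None) else cluster_cpu
    return lib.graclus(row, col, num_nodes)
```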
```diff
@@ -7,9 +7,9 @@
 at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
   std::tie(row, col) = remove_self_loops(row, col);
   std::tie(row, col) = randperm(row, col, num_nodes);
-  auto deg = degree(row, num_nodes, row.type().scalarType());
   auto cluster = at::full(row.type(), {num_nodes}, -1);
+  auto deg = degree(row, num_nodes, row.type().scalarType());

   auto *row_data = row.data<int64_t>();
   auto *col_data = col.data<int64_t>();
...
```
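The hunk above only reorders the setup (allocate `cluster` filled with -1, then compute degrees). For orientation, a hedged Python sketch of the unweighted greedy matching such scaffolding typically feeds, where each node claims one unmatched neighbor; the actual kernel may pick neighbors differently, and `greedy_graclus` is a name invented here for illustration:

```python
import torch

def greedy_graclus(row, col, num_nodes):
    cluster = torch.full((num_nodes,), -1, dtype=torch.long)
    # Visit nodes in random order, mirroring the randperm() call above.
    for u in torch.randperm(num_nodes).tolist():
        if cluster[u] >= 0:          # u is already matched
            continue
        cluster[u] = u               # open a new cluster rooted at u
        for v in col[row == u].tolist():
            if cluster[v] < 0:       # contract the first unmatched neighbor
                cluster[v] = u
                break
    return cluster
```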
cuda/cluster.cpp

```diff
 #include <torch/torch.h>

-at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
-                     at::Tensor end);
-
 #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

-at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
-                at::Tensor end) {
-  CHECK_CUDA(pos);
-  CHECK_CUDA(size);
-  CHECK_CUDA(start);
-  CHECK_CUDA(end);
-  return grid_cuda(pos, size, start, end);
-}
+#include "graclus.cpp"
+#include "grid.cpp"

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("graclus", &graclus, "Graclus (CUDA)");
   m.def("grid", &grid, "Grid (CUDA)");
 }
```
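Including the two `.cpp` files keeps everything in one translation unit, so a single `PYBIND11_MODULE` can register both ops. A hedged usage sketch of the resulting module (tensor values are illustrative only):

```python
import torch
import cluster_cuda  # module name set by the CUDAExtension in setup.py

row = torch.tensor([0, 0, 1, 2], device='cuda')
col = torch.tensor([1, 2, 0, 0], device='cuda')
# At this commit the CUDA graclus (next file) is still a stub that strips
# self-loops and returns node degrees rather than a clustering.
print(cluster_cuda.graclus(row, col, 3))

pos = torch.rand(5, 2, device='cuda')
size = torch.tensor([0.5, 0.5], device='cuda')
start, end = torch.zeros(2, device='cuda'), torch.ones(2, device='cuda')
print(cluster_cuda.grid(pos, size, start, end))
```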
cuda/graclus.cpp (new file)

```diff
+#include <torch/torch.h>
+
+#include "../include/degree.cpp"
+#include "../include/loop.cpp"
+
+at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
+  CHECK_CUDA(row);
+  CHECK_CUDA(col);
+
+  std::tie(row, col) = remove_self_loops(row, col);
+  auto deg = degree(row, num_nodes, row.type().scalarType());
+  return deg;
+}
```
cuda/grid.cpp (new file)

```diff
+#include <torch/torch.h>
+
+at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
+                     at::Tensor end);
+
+at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
+                at::Tensor end) {
+  CHECK_CUDA(pos);
+  CHECK_CUDA(size);
+  CHECK_CUDA(start);
+  CHECK_CUDA(end);
+  return grid_cuda(pos, size, start, end);
+}
```
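`grid_cuda` itself lives in a `.cu` file picked up by `setup.py`. As a rough guide to the operation being wrapped, a hedged Python sketch of voxel-grid clustering: each point is assigned the integer id of the voxel containing it (the real kernel's voxel numbering may differ):

```python
import torch

def grid_cluster(pos, size, start, end):
    num = ((end - start) / size).long() + 1  # voxels per dimension
    coords = ((pos - start) / size).long()   # per-dimension voxel coordinate
    cluster = coords[:, 0]
    for d in range(1, pos.size(1)):          # ravel into a single voxel id
        cluster = cluster * num[d] + coords[:, d]
    return cluster
```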
#include "degree.h" #ifndef DEGREE_INC
#define DEGREE_INC
#include <torch/torch.h> #include <torch/torch.h>
...@@ -8,3 +9,5 @@ inline at::Tensor degree(at::Tensor index, int num_nodes, ...@@ -8,3 +9,5 @@ inline at::Tensor degree(at::Tensor index, int num_nodes,
auto one = at::full(zero.type(), {index.size(0)}, 1); auto one = at::full(zero.type(), {index.size(0)}, 1);
return zero.scatter_add_(0, index, one); return zero.scatter_add_(0, index, one);
} }
#endif // DEGREE_INC
include/degree.h (deleted)

```diff
-#ifndef DEGREE_INC
-#define DEGREE_INC
-
-#include <torch/torch.h>
-
-inline at::Tensor degree(at::Tensor index, int num_nodes,
-                         at::ScalarType scalar_type);
-
-#endif // DEGREE_INC
```
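The helper is small enough to restate: a Python equivalent of `degree()`, counting how often each node id occurs via `scatter_add_`:

```python
import torch

def degree(index, num_nodes, dtype=torch.long):
    zero = torch.zeros(num_nodes, dtype=dtype)
    one = torch.ones(index.size(0), dtype=dtype)
    # Add 1 at position index[i] for every entry, i.e. a bincount.
    return zero.scatter_add_(0, index, one)

row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
print(degree(row, 4).tolist())  # [2, 3, 3, 2]
```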
#include "loop.h" #ifndef LOOP_INC
#define LOOP_INC
#include <torch/torch.h> #include <torch/torch.h>
...@@ -7,3 +8,5 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, ...@@ -7,3 +8,5 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
auto mask = row != col; auto mask = row != col;
return {row.masked_select(mask), col.masked_select(mask)}; return {row.masked_select(mask), col.masked_select(mask)};
} }
#endif // LOOP_INC
include/loop.h (deleted)

```diff
-#ifndef LOOP_INC
-#define LOOP_INC
-
-#include <torch/torch.h>
-
-inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
-                                                            at::Tensor col);
-
-#endif // LOOP_INC
```
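Likewise, a Python equivalent of `remove_self_loops()`: keep only the edges whose endpoints differ:

```python
import torch

def remove_self_loops(row, col):
    mask = row != col          # True for every non-self-loop edge
    return row[mask], col[mask]

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([0, 2, 1, 0])
print(remove_self_loops(row, col))  # (tensor([1, 2]), tensor([2, 0]))
```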
#include "perm.h" #ifndef PERM_INC
#define PERM_INC
#include <torch/torch.h> #include <torch/torch.h>
...@@ -22,3 +23,5 @@ randperm(at::Tensor row, at::Tensor col, int num_nodes) { ...@@ -22,3 +23,5 @@ randperm(at::Tensor row, at::Tensor col, int num_nodes) {
return {row, col}; return {row, col};
} }
#endif // PERM_INC
include/perm.h (deleted)

```diff
-#ifndef PERM_INC
-#define PERM_INC
-
-#include <torch/torch.h>
-
-inline std::tuple<at::Tensor, at::Tensor>
-randperm(at::Tensor row, at::Tensor col, int num_nodes);
-
-#endif // PERM_INC
```
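The body of `randperm()` is collapsed in this diff, so the following is only a hedged guess at one plausible scheme: relabel sources by a random node permutation and sort the edges by that rank, so nodes are visited in random order during matching:

```python
import torch

def randperm_edges(row, col, num_nodes):
    rank = torch.randperm(num_nodes)  # random rank per node id
    order = rank[row].argsort()       # order edges by the rank of their source
    return row[order], col[order]
```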
setup.py

```diff
-import torch
+import glob
 from setuptools import setup
+import torch.cuda
 from torch.utils.cpp_extension import CppExtension, CUDAExtension

-ext_modules = [CppExtension(name='cluster_cpu', sources=['cpu/cluster.cpp'])]
+ext_modules = [CppExtension('cluster_cpu', ['cpu/cluster.cpp'])]
 if torch.cuda.is_available():
     ext_modules += [
-        CUDAExtension(
-            name='cluster_cuda',
-            sources=['cuda/cluster.cpp', 'cuda/cluster_kernel.cu'])
+        CUDAExtension('cluster_cuda',
+                      ['cuda/cluster.cpp'] + glob.glob('cuda/*.cu'))
     ]

 setup(
...
```
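With this change, any `.cu` kernel dropped into `cuda/` is compiled into `cluster_cuda` without touching the source list again. A small sketch of the discovery step (the file name in the comment is hypothetical):

```python
import glob
import torch

sources = ['cuda/cluster.cpp'] + glob.glob('cuda/*.cu')
print(sources)                    # e.g. ['cuda/cluster.cpp', 'cuda/grid_kernel.cu']
print(torch.cuda.is_available())  # gates whether cluster_cuda is built at all
```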