Commit fbd14a95 authored by rusty1s

setup for cuda and cpu

parent 4bc9a767
 import torch
 import cluster_cpu
+import cluster_cuda


-def grid_cluster(pos, size, start=None, end=None):
+def grid(pos, size, start=None, end=None):
+    lib = cluster_cuda if pos.is_cuda else cluster_cpu
     start = pos.t().min(dim=1)[0] if start is None else start
     end = pos.t().max(dim=1)[0] if end is None else end
-    return cluster_cpu.grid(pos, size, start, end)
+    return lib.grid(pos, size, start, end)
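For orientation, here is a rough pure-PyTorch sketch of what the compiled grid kernels compute: each point is bucketed into a regular grid and mapped to a flat cell index. The helper name and the flattening order are assumptions for illustration, not taken from this commit:

import torch

def grid_reference(pos, size, start=None, end=None):
    # Assumed semantics of lib.grid: one flat voxel index per point.
    start = pos.t().min(dim=1)[0] if start is None else start
    end = pos.t().max(dim=1)[0] if end is None else end
    num_cells = ((end - start) / size).long() + 1  # cells per dimension
    coord = ((pos - start) / size).long()          # per-dimension cell coordinate
    cluster = torch.zeros(pos.size(0), dtype=torch.long)
    stride = 1
    for d in range(pos.size(1)):                   # assumed row-major flattening
        cluster += coord[:, d] * stride
        stride *= int(num_cells[d])
    return cluster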
-def graclus_cluster(row, col, num_nodes):
+def graclus(row, col, num_nodes):
     return cluster_cpu.graclus(row, col, num_nodes)
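Conceptually, graclus performs a greedy matching: nodes are visited in random order and paired with one unmatched neighbor, with unmatched nodes staying singletons. A hedged Python sketch of that idea (illustrative only, not the kernel's actual implementation):

import torch

def graclus_reference(row, col, num_nodes):
    cluster = torch.full((num_nodes,), -1, dtype=torch.long)
    for u in torch.randperm(num_nodes).tolist():  # random visiting order
        if cluster[u] >= 0:
            continue                              # already matched
        cluster[u] = u                            # fall back to a singleton
        for v in col[row == u].tolist():          # neighbors of u
            if v != u and cluster[v] < 0:
                cluster[v] = u                    # greedily pair u with v
                break
    return cluster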
-pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
-size = torch.tensor([2, 2])
-start = torch.tensor([0, 0])
-end = torch.tensor([7, 7])
+device = torch.device('cuda')
+pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]], device=device)
+size = torch.tensor([2, 2], device=device)

 print('pos', pos.tolist())
 print('size', size.tolist())

-cluster = grid_cluster(pos, size)
-print('result', cluster.tolist(), cluster.dtype)
+cluster = grid(pos, size)
+print('result', cluster.tolist(), cluster.dtype, cluster.device)
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
print(row)
print(col)
print('-----------------')
-cluster = graclus_cluster(row, col, 4)
+cluster = graclus(row, col, 4)
print(cluster)
 #include <torch/torch.h>

-#include "degree.cpp"
-#include "loop.cpp"
-#include "perm.cpp"
+#include "../include/degree.cpp"
+#include "../include/loop.cpp"
+#include "../include/perm.cpp"

 at::Tensor graclus(at::Tensor row, at::Tensor col, int num_nodes) {
   std::tie(row, col) = remove_self_loops(row, col);
   ...
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='cluster',
    ext_modules=[CppExtension('cluster_cpu', ['cluster.cpp'])],
    cmdclass={'build_ext': BuildExtension},
)
import torch
import cluster_cuda

dtype = torch.float
device = torch.device('cuda')


def grid_cluster(pos, size, start=None, end=None):
    start = pos.t().min(dim=1)[0] if start is None else start
    end = pos.t().max(dim=1)[0] if end is None else end
    return cluster_cuda.grid(pos, size, start, end)


pos = torch.tensor(
    [[1, 1], [3, 3], [5, 5], [7, 7]], dtype=dtype, device=device)
size = torch.tensor([2, 2], dtype=dtype, device=device)  # one entry per pos dimension

# print('pos', pos.tolist())
# print('size', size.tolist())

cluster = grid_cluster(pos, size)
print('result', cluster.tolist(), cluster.type())
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='cluster_cuda',
    ext_modules=[
        CUDAExtension('cluster_cuda', ['cluster.cpp', 'cluster_kernel.cu'])
    ],
    cmdclass={'build_ext': BuildExtension},
)
#ifndef DEGREE_CPP
#define DEGREE_CPP

#include "degree.h"
#include <torch/torch.h>

@@ -9,5 +8,3 @@ inline at::Tensor degree(at::Tensor index, int num_nodes,
   auto one = at::full(zero.type(), {index.size(0)}, 1);
   return zero.scatter_add_(0, index, one);
 }

#endif // DEGREE_CPP
#ifndef DEGREE_INC
#define DEGREE_INC

#include <torch/torch.h>

inline at::Tensor degree(at::Tensor index, int num_nodes,
                         at::ScalarType scalar_type);

#endif // DEGREE_INC
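The degree helper above counts how many entries of index refer to each node by scatter-adding ones. A direct Python analogue (a sketch; the scalar_type argument is simplified to a dtype):

import torch

def degree(index, num_nodes, dtype=torch.long):
    zero = torch.zeros(num_nodes, dtype=dtype)
    one = torch.ones(index.size(0), dtype=dtype)
    return zero.scatter_add_(0, index, one)  # occurrence count per node id

For the edge list used in the test above, degree(row, 4) yields tensor([2, 3, 3, 2]).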
#ifndef LOOP_CPP
#define LOOP_CPP

#include "loop.h"
#include <torch/torch.h>

@@ -8,5 +7,3 @@ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
   auto mask = row != col;
   return {row.masked_select(mask), col.masked_select(mask)};
 }

#endif // LOOP_CPP
#ifndef LOOP_INC
#define LOOP_INC

#include <torch/torch.h>

inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
                                                            at::Tensor col);

#endif // LOOP_INC
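remove_self_loops simply masks out edges whose endpoints coincide. The same masking in Python (a sketch mirroring the C++ body above):

import torch

def remove_self_loops(row, col):
    mask = row != col           # keep only edges with distinct endpoints
    return row[mask], col[mask]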
#ifndef PERM_CPP
#define PERM_CPP

#include "perm.h"
#include <torch/torch.h>

@@ -23,5 +22,3 @@ randperm(at::Tensor row, at::Tensor col, int num_nodes) {
   return {row, col};
 }

#endif // PERM_CPP
#ifndef PERM_INC
#define PERM_INC

#include <torch/torch.h>

inline std::tuple<at::Tensor, at::Tensor>
randperm(at::Tensor row, at::Tensor col, int num_nodes);

#endif // PERM_INC
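The body of randperm is collapsed in this diff, so the following can only guess at its intent: reorder or relabel edges by a random node permutation so that the matching in graclus visits nodes in random order. Everything in this sketch is hypothetical:

import torch

def randperm_edges(row, col, num_nodes):
    # Hypothetical: relabel both endpoint lists by a random permutation.
    perm = torch.randperm(num_nodes)
    return perm[row], perm[col]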
import torch
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, CUDAExtension

ext_modules = [CppExtension(name='cluster_cpu', sources=['cpu/cluster.cpp'])]

if torch.cuda.is_available():
    ext_modules += [
        CUDAExtension(
            name='cluster_cuda',
            sources=['cuda/cluster.cpp', 'cuda/cluster_kernel.cu'])
    ]

setup(
    name='cluster',
    ext_modules=ext_modules,
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)
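With this unified script, a standard setuptools build (for example, python setup.py install) always compiles cluster_cpu and adds cluster_cuda only when CUDA is available at build time. A minimal smoke test after installation, assuming both extensions built (the float dtype follows the CUDA test above):

import torch
import cluster_cpu  # built unconditionally

pos = torch.tensor([[1., 1.], [3., 3.], [5., 5.], [7., 7.]])
size = torch.tensor([2., 2.])
start = pos.t().min(dim=1)[0]
end = pos.t().max(dim=1)[0]
print(cluster_cpu.grid(pos, size, start, end))

if torch.cuda.is_available():
    import cluster_cuda  # present only if CUDA was detected during setup
    print(cluster_cuda.grid(pos.cuda(), size.cuda(), start.cuda(), end.cuda()))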