Commit 05877042 authored by rusty1s

cleanup

parent 672358fd
@@ -13,14 +13,14 @@ def graclus_cluster(row, col, num_nodes):
     return cluster_cpu.graclus(row, col, num_nodes)
 
-# pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
-# size = torch.tensor([2, 2])
-# start = torch.tensor([0, 0])
-# end = torch.tensor([7, 7])
-# print('pos', pos.tolist())
-# print('size', size.tolist())
-# cluster = grid_cluster(pos, size)
-# print('result', cluster.tolist(), cluster.dtype)
+pos = torch.tensor([[1, 1], [3, 3], [5, 5], [7, 7]])
+size = torch.tensor([2, 2])
+start = torch.tensor([0, 0])
+end = torch.tensor([7, 7])
+print('pos', pos.tolist())
+print('size', size.tolist())
+cluster = grid_cluster(pos, size)
+print('result', cluster.tolist(), cluster.dtype)
 
 row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
 col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
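The uncommented block above exercises grid_cluster. For orientation, here is a pure-PyTorch sketch of voxel-grid clustering; grid_cluster_reference and its cell-flattening order are assumptions for illustration, not the extension's actual implementation:

import torch

def grid_cluster_reference(pos, size, start=None, end=None):
    # Hypothetical reference: assign each point to a voxel cell
    # c = floor((pos - start) / size), then flatten the d-dimensional
    # cell coordinate into a single cluster id.
    start = pos.min(dim=0)[0] if start is None else start
    end = pos.max(dim=0)[0] if end is None else end
    num_cells = ((end - start).float() / size.float()).floor().long() + 1
    coord = ((pos - start).float() / size.float()).floor().long()
    cluster = coord[:, 0]
    for d in range(1, pos.size(1)):  # flattening order is an assumption
        cluster = cluster * num_cells[d] + coord[:, d]
    return cluster

With the test values above (points 2 apart, cells of size 2), every point lands in its own cell, so this sketch yields four distinct cluster ids.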
@@ -5,7 +5,7 @@
 inline at::Tensor degree(at::Tensor index, int num_nodes,
                          at::ScalarType scalar_type) {
-  auto zero = at::full(torch::CPU(scalar_type), {num_nodes}, 0);
+  auto zero = at::full(index.type().toScalarType(scalar_type), {num_nodes}, 0);
   auto one = at::full(zero.type(), {index.size(0)}, 1);
   return zero.scatter_add_(0, index, one);
 }
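The change swaps a hard-coded CPU type for one derived from index, so the degree helper allocates on whatever backend index lives on. A rough Python equivalent of the helper (a sketch only, not the library's API):

import torch

def degree(index, num_nodes, dtype=torch.long):
    # Count how often each node id occurs in `index`.
    # Deriving device from `index` mirrors the fix above:
    # the zeros live on the same device as the input.
    zero = torch.zeros(num_nodes, dtype=dtype, device=index.device)
    one = torch.ones(index.size(0), dtype=dtype, device=index.device)
    return zero.scatter_add_(0, index, one)

# degree(torch.tensor([0, 0, 1, 2]), 3) -> tensor([2, 1, 1])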
@@ -6,12 +6,12 @@
 inline std::tuple<at::Tensor, at::Tensor>
 randperm(at::Tensor row, at::Tensor col, int num_nodes) {
   // Randomly reorder row and column indices.
-  auto perm = at::randperm(torch::CPU(at::kLong), row.size(0));
+  auto perm = at::randperm(row.type(), row.size(0));
   row = row.index_select(0, perm);
   col = col.index_select(0, perm);
 
   // Randomly swap row values.
-  auto node_rid = at::randperm(torch::CPU(at::kLong), num_nodes);
+  auto node_rid = at::randperm(row.type(), num_nodes);
   row = node_rid.index_select(0, row);
 
   // Sort row and column indices row-wise.
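For intuition, a loose Python rendering of this shuffling step; the final sort is assumed from the truncated trailing comment (a sketch, not the extension code):

import torch

def shuffle_edges(row, col, num_nodes):
    # Randomly reorder the edge list.
    perm = torch.randperm(row.size(0), device=row.device)
    row, col = row[perm], col[perm]
    # Randomly relabel row ids so the subsequent row-wise sort
    # visits nodes in random order.
    node_rid = torch.randperm(num_nodes, device=row.device)
    row = node_rid[row]
    # Sort row and column indices row-wise (per the comment above).
    row, perm = row.sort()
    col = col[perm]
    return row, col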
@@ -5,8 +5,9 @@
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
 
 template <typename scalar_t>
-__global__ void grid_cuda_kernel(
-    int64_t *cluster, at::cuda::detail::TensorInfo<scalar_t, int> pos,
+__global__ void
+grid_cuda_kernel(int64_t *cluster,
+                 at::cuda::detail::TensorInfo<scalar_t, int> pos,
                  scalar_t *__restrict__ size, scalar_t *__restrict__ start,
                  scalar_t *__restrict__ end, size_t num_nodes) {
   const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -26,7 +27,7 @@ __global__ void grid_cuda_kernel(
 at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                      at::Tensor end) {
   auto num_nodes = pos.size(0);
-  auto cluster = at::empty(pos.type().toType(at::kLong), {num_nodes});
+  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});
   AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda_kernel", [&] {
     grid_cuda_kernel<scalar_t><<<BLOCKS(num_nodes), THREADS>>>(
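The BLOCKS(N) macro in the context above is plain ceil division: launch enough blocks of THREADS threads each to give every node its own thread. In Python terms (the THREADS value here is assumed for illustration; the real constant is defined earlier in the file):

THREADS = 1024  # assumed value, for illustration only

def blocks(n, threads=THREADS):
    # Ceil division without floats: (n + threads - 1) // threads.
    return (n + threads - 1) // threads

# blocks(1) == 1, blocks(1024) == 1, blocks(1025) == 2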