Commit ba9f2ed2 authored by rusty1s

added grid cuda implementation

parent 26f5fa37
#include <torch/extension.h>
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end);

at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  CHECK_CUDA(start);
  CHECK_CUDA(end);
  return grid_cuda(pos, size, start, end);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("grid", &grid, "Grid (CUDA)");
}
#include "grid_cpu.h"
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                            const scalar_t *start, const scalar_t *end,
                            int64_t *out, int64_t N, int64_t D, int64_t numel) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (thread_idx < numel) {
    // c accumulates the flattened voxel index of point `thread_idx`; k holds
    // the running product of the number of voxels in all previous dimensions.
    int64_t c = 0, k = 1;
    for (int64_t d = 0; d < D; d++) {
      scalar_t p = pos[thread_idx * D + d] - start[d];
      c += (int64_t)(p / size[d]) * k;
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    out[thread_idx] = c;
  }
}
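// A worked example of the indexing above (illustrative numbers, not taken from
// the commit): with start = (0, 0), end = (4, 4) and size = (2, 2), each
// dimension holds (4 - 0) / 2 + 1 = 3 voxels. A point at (3.5, 1.0) maps to
// c = floor(3.5 / 2) * 1 + floor(1.0 / 2) * 3 = 1, i.e. cluster id 1.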
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
                        torch::optional<torch::Tensor> optional_start,
                        torch::optional<torch::Tensor> optional_end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  cudaSetDevice(pos.get_device());
  if (optional_start.has_value())
    CHECK_CUDA(optional_start.value());
  if (optional_end.has_value())
    CHECK_CUDA(optional_end.value());
  pos = pos.view({pos.size(0), -1}).contiguous();
  size = size.contiguous();
  CHECK_INPUT(size.numel() == pos.size(1));

  if (!optional_start.has_value())
    optional_start = std::get<0>(pos.min(0));
  else {
    optional_start = optional_start.value().contiguous();
    CHECK_INPUT(optional_start.value().numel() == pos.size(1));
  }
  if (!optional_end.has_value())
    optional_end = std::get<0>(pos.max(0));
  else {
    optional_end = optional_end.value().contiguous();
    CHECK_INPUT(optional_end.value().numel() == pos.size(1));
  }
  auto start = optional_start.value();
  auto end = optional_end.value();

  auto out = torch::empty({pos.size(0)}, pos.options().dtype(torch::kLong));

  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
        start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
        out.data_ptr<int64_t>(), pos.size(0), pos.size(1), out.numel());
  });

  return out;
}
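A minimal host-side usage sketch of the function above (not part of the commit; the file name and the chosen values are assumptions), clustering random 3-D points into voxels of edge length 0.5 and letting start/end default to the per-dimension minima and maxima:

// example_grid_usage.cpp (hypothetical) -- links against the extension above.
#include <iostream>
#include <torch/torch.h>
#include "grid_cuda.h"

int main() {
  auto pos = torch::rand({8, 3}, torch::device(torch::kCUDA));  // 8 points in 3-D
  auto size = torch::full({3}, 0.5, pos.options());             // voxel edge length per dimension
  // Empty optionals: start/end are derived from pos.min(0) / pos.max(0).
  auto cluster = grid_cuda(pos, size, {}, {});
  std::cout << cluster << std::endl;  // one flat voxel id per point
  return 0;
}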
#pragma once
#include <torch/extension.h>
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
                        torch::optional<torch::Tensor> optional_start,
                        torch::optional<torch::Tensor> optional_end);
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "compat.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void grid_kernel(int64_t *cluster,
                            at::cuda::detail::TensorInfo<scalar_t, int64_t> pos,
                            scalar_t *__restrict__ size,
                            scalar_t *__restrict__ start,
                            scalar_t *__restrict__ end, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t c = 0, k = 1;
    for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
      scalar_t p = pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
      c += (int64_t)(p / size[d]) * k;
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}
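// Unlike the raw-pointer kernel further up, this TensorInfo-based variant
// indexes pos through its strides, so it also handles non-contiguous inputs:
// for a transposed pos of logical shape [N, D] the strides would be [1, N] and
// the read above becomes pos.data[i * 1 + d * N]. The grid-stride loop
// (i += stride) additionally lets a fixed launch cover arbitrarily many points.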
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  cudaSetDevice(pos.get_device());
  auto cluster = at::empty(pos.size(0), pos.options().dtype(at::kLong));

  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
        cluster.DATA_PTR<int64_t>(),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(pos),
        size.DATA_PTR<scalar_t>(), start.DATA_PTR<scalar_t>(),
        end.DATA_PTR<scalar_t>(), cluster.numel());
  });

  return cluster;
}
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
////////////////////////////////////////////////////////////////////////
#include <ATen/ATen.h>
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
......