Commit 8b5daa16 authored by rusty1s
Browse files

gridkernel done

parent ad63397e
#include "THCGrid.h"

#include "common.h"
#include "THCNumerics.cuh"

// Assigns each of the `nNodes` points to a cell of a regular grid and writes
// the flattened cell id into `cluster[i]`.
//
// posInfo : positions tensor, shape (nNodes, nDims) — row i holds point i's
//           coordinates (strides honored, so non-contiguous input works).
// size    : per-dimension grid cell size, length nDims.
// count   : per-dimension number of grid cells, length nDims; used as the
//           mixed-radix base when flattening the cell coordinates.
//
// Launched via KERNEL_RUN (1-D grid-stride loop over nNodes).
template<typename T>
__global__ void gridKernel(int64_t *cluster, TensorInfo<T> posInfo, T *size,
                           int64_t *count, const int nNodes) {
  KERNEL_LOOP(i, nNodes) {
    // Start of row i in the positions tensor.
    T *pos = posInfo.data + i * posInfo.stride[0];
    int64_t coef = 1, value = 0;
    // Iterate over the coordinate dimensions. The bound must be the number
    // of coordinates per point (posInfo.size[1]), not posInfo.dims — the
    // latter is the tensor rank (always 2 here). `size` and `count` are
    // indexed by the dimension index d, while the position is read through
    // the inner stride, so non-contiguous inputs stay correct.
    for (ptrdiff_t d = 0; d < posInfo.size[1]; d++) {
      value += coef * (int64_t) THCNumerics<T>::floor(
                          THCNumerics<T>::div(pos[d * posInfo.stride[1]], size[d]));
      coef *= count[d];
    }
    cluster[i] = value;
  }
}
#ifndef THC_NUMERICS_INC
#define THC_NUMERICS_INC

#include "THC/THCHalf.h"

// Small numeric helpers usable from both host and device code.
// `div` performs plain division; `floor` rounds toward negative infinity and
// returns the result as an int (callers accumulate it into an int64_t).
template<typename T>
struct THCNumerics {
  static inline __host__ __device__ T div(T a, T b) { return a / b; }
  // Round toward negative infinity. A bare cast truncates toward zero, which
  // is wrong for negative inputs (e.g. (int) -0.5 == 0, but floor is -1) and
  // would assign negative coordinates to the wrong grid cell. Implemented
  // without <cmath> so it also works for integral T (where the adjustment
  // branch is never taken).
  static inline __host__ __device__ int floor(T a) {
    int t = (int) a;                  // truncates toward zero
    return (a < (T) t) ? t - 1 : t;   // fix up negative fractional values
  }
};

#ifdef CUDA_HALF_TENSOR
#ifdef __CUDA_ARCH__
#define h2f(A) __half2float(A)
#define f2h(A) __float2half(A)
#else // __CUDA_ARCH__
#define h2f(A) THC_half2float(A)
#define f2h(A) THC_float2half(A)
#endif
// Specialization for half: compute in float, convert back for `div`.
template<>
struct THCNumerics<half> {
  static inline __host__ __device__ half div(half a, half b) { return f2h(h2f(a) / h2f(b)); }
  // Same floor-toward-negative-infinity fix as the generic case, applied to
  // the float value of the half input.
  static inline __host__ __device__ int floor(half a) {
    float f = h2f(a);
    int t = (int) f;
    return (f < (float) t) ? t - 1 : t;
  }
};
#endif // CUDA_HALF_TENSOR

#endif // THC_NUMERICS_INC
#ifndef THC_COMMON_INC
#define THC_COMMON_INC

// Expands to the type-specialized TH tensor function name (TH generic idiom).
#define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)

// 1-D grid-stride loop over N elements; valid for any launch configuration.
#define KERNEL_LOOP(I, N) \
  for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)

#define THC_assertSameGPU(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \
  "Some of the input tensors are located on different GPUs. Please move them to a single one.")

// Maximum tensor rank a TensorInfo can describe.
const int MAX_DIMS = 25;
// Threads per block for all kernel launches in this extension.
const int NUM_THREADS = 1024;

// Ceil-division: number of blocks needed to cover N elements.
inline int GET_BLOCKS(const int N) {
  return (N + NUM_THREADS - 1) / NUM_THREADS;
}

// Launches NAME<real> on the current stream and checks for launch errors.
// NOTE(review): the first two macro lines were hidden under a diff fold; they
// are reconstructed from the deleted FIXED_DIM_KERNEL_RUN macro — confirm
// against the repository.
#define KERNEL_RUN(NAME, N, ...) \
  int grid = GET_BLOCKS(N); \
  cudaStream_t stream = THCState_getCurrentStream(state); \
  NAME<real><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
  THCudaCheck(cudaGetLastError())

// POD mirror of a tensor's layout, passable by value to kernels:
// raw data pointer plus rank, sizes and strides (capped at MAX_DIMS).
template<typename T>
struct TensorInfo {
  T *data;
  int dims;
  int size[MAX_DIMS];
  int stride[MAX_DIMS];
};

#include "generic/common.h"
#include "THC/THCGenerateAllTypes.h"

#endif // THC_COMMON_INC
...@@ -7,13 +7,12 @@ void THCGrid_(THCState *state, THCudaLongTensor *cluster, THCTensor *pos, THCTen ...@@ -7,13 +7,12 @@ void THCGrid_(THCState *state, THCudaLongTensor *cluster, THCTensor *pos, THCTen
THC_assertSameGPU(state, 4, cluster, pos, size, count); THC_assertSameGPU(state, 4, cluster, pos, size, count);
int64_t *clusterData = THCudaLongTensor_data(state, cluster); int64_t *clusterData = THCudaLongTensor_data(state, cluster);
TensorInfo<real> posInfo = THC_(getTensorInfo)(state, pos); TensorInfo<real> posInfo = THCTensor_(getTensorInfo)(state, pos);
real *sizeData = THCTensor_(data)(state, size); real *sizeData = THCTensor_(data)(state, size);
int64_t *countData = THCudaLongTensor_data(state, count); int64_t *countData = THCudaLongTensor_data(state, count);
const int nNodes = THCudaLongTensor_nElement(state, cluster); const int nNodes = THCudaLongTensor_nElement(state, cluster);
const int dims = THCTensor_(nElement)(size); KERNEL_RUN(gridKernel, nNodes, clusterData, posInfo, sizeData, countData);
FIXED_DIM_KERNEL_RUN(gridKernel, nNodes, dims, clusterData, posInfo, sizeData, countData);
} }
#endif // THC_GENERIC_FILE #endif // THC_GENERIC_FILE
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/common.h"
#else

// Collects `tensor`'s data pointer, rank, sizes and strides into a POD
// TensorInfo<real> that can be passed by value to a kernel.
TensorInfo<real> THCTensor_(getTensorInfo)(THCState *state, THCTensor *tensor) {
  TensorInfo<real> tensorInfo = TensorInfo<real>();
  tensorInfo.data = THCTensor_(data)(state, tensor);
  tensorInfo.dims = THCTensor_(nDimension)(state, tensor);
  // Guard the fixed-size size/stride arrays: TensorInfo holds at most
  // MAX_DIMS dimensions, so a higher-rank tensor would overflow them.
  // NOTE(review): THAssertMsg is the assertion used elsewhere in this
  // extension (see THC_assertSameGPU) — confirm it is in scope here.
  THAssertMsg(tensorInfo.dims <= MAX_DIMS,
              "Tensor has too many dimensions (maximum is MAX_DIMS)");
  for (ptrdiff_t d = 0; d < tensorInfo.dims; d++) {
    tensorInfo.size[d] = THCTensor_(size)(state, tensor, d);
    tensorInfo.stride[d] = THCTensor_(stride)(state, tensor, d);
  }
  return tensorInfo;
}

#endif // THC_GENERIC_FILE
import torch
from torch_cluster._ext import ffi

# Smoke test for the CUDA grid-clustering kernel: five 2-D points, a 2x2
# cell size and a 3x3 grid of cells. Points sharing a cell (rows 0/2 and
# rows 1/4) should receive the same cluster id.
cluster = torch.cuda.LongTensor(5)
pos = torch.cuda.FloatTensor([[1, 1], [3, 3], [1, 1], [5, 5], [3, 3]])
size = torch.cuda.FloatTensor([2, 2])
count = torch.cuda.LongTensor([3, 3])

func = ffi.THCCFloatGrid
print(func)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment