Commit 653e96d9 authored by rusty1s's avatar rusty1s
Browse files

outsourced cuda code

parent b9a7f326
...@@ -20,8 +20,8 @@ with_cuda = False ...@@ -20,8 +20,8 @@ with_cuda = False
if torch.cuda.is_available(): if torch.cuda.is_available():
subprocess.call(['./build.sh', osp.dirname(torch.__file__)]) subprocess.call(['./build.sh', osp.dirname(torch.__file__)])
headers += ['torch_cluster/src/cuda.h'] headers += ['torch_cluster/src/{}_cuda.h'.format(f) for f in files]
sources += ['torch_cluster/src/cuda.c'] sources += ['torch_cluster/src/{}_cuda.c'.format(f) for f in files]
include_dirs += ['torch_cluster/kernel'] include_dirs += ['torch_cluster/kernel']
define_macros += [('WITH_CUDA', None)] define_macros += [('WITH_CUDA', None)]
extra_objects += ['torch_cluster/build/kernel.so'] extra_objects += ['torch_cluster/build/kernel.so']
......
#ifndef THC_GENERIC_FILE #ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/cuda.c" #define THC_GENERIC_FILE "generic/grid_cuda.c"
#else #else
void cluster_(grid)(int C, THCudaLongTensor *output, THCTensor *position, THCTensor *size, THCudaLongTensor *count) { void cluster_(grid)(int C, THCudaLongTensor *output, THCTensor *position, THCTensor *size, THCudaLongTensor *count) {
......
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/serial_cuda.c"
#else
void cluster_(serial)(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCTensor *weight, THCudaLongTensor *degree) {
}
#endif
...@@ -7,17 +7,17 @@ ...@@ -7,17 +7,17 @@
extern THCState *state; extern THCState *state;
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateFloatType.h" #include "THCGenerateFloatType.h"
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateDoubleType.h" #include "THCGenerateDoubleType.h"
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateByteType.h" #include "THCGenerateByteType.h"
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateCharType.h" #include "THCGenerateCharType.h"
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateShortType.h" #include "THCGenerateShortType.h"
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateIntType.h" #include "THCGenerateIntType.h"
#include "generic/cuda.c" #include "generic/grid_cuda.c"
#include "THCGenerateLongType.h" #include "THCGenerateLongType.h"
#include <THC/THC.h>
#include "kernel.h"
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _cuda_, Real)
#define cluster_kernel_(NAME) TH_CONCAT_4(cluster_, NAME, _kernel_, Real)
extern THCState *state;
void cluster_serial_cuda(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree) {
}
#include "generic/serial_cuda.c"
#include "THCGenerateFloatType.h"
#include "generic/serial_cuda.c"
#include "THCGenerateDoubleType.h"
#include "generic/serial_cuda.c"
#include "THCGenerateByteType.h"
#include "generic/serial_cuda.c"
#include "THCGenerateCharType.h"
#include "generic/serial_cuda.c"
#include "THCGenerateShortType.h"
#include "generic/serial_cuda.c"
#include "THCGenerateIntType.h"
#include "generic/serial_cuda.c"
#include "THCGenerateLongType.h"
void cluster_serial_cuda(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *degree);
void cluster_serial_cuda_Float (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaTensor *weight, THCudaLongTensor *degree);
void cluster_serial_cuda_Double(THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaDoubleTensor *weight, THCudaLongTensor *degree);
void cluster_serial_cuda_Byte (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaByteTensor *weight, THCudaLongTensor *degree);
void cluster_serial_cuda_Char (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaCharTensor *weight, THCudaLongTensor *degree);
void cluster_serial_cuda_Short (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaShortTensor *weight, THCudaLongTensor *degree);
void cluster_serial_cuda_Int (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaIntTensor *weight, THCudaLongTensor *degree);
void cluster_serial_cuda_Long (THCudaLongTensor *output, THCudaLongTensor *row, THCudaLongTensor *col, THCudaLongTensor *weight, THCudaLongTensor *degree);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment