Commit 12346b1b authored by rusty1s's avatar rusty1s
Browse files

allow multiple kernels

parent f2c8aa1e
......@@ -24,7 +24,7 @@ if torch.cuda.is_available():
sources += ['torch_cluster/src/{}_cuda.c'.format(f) for f in files]
include_dirs += ['torch_cluster/kernel']
define_macros += [('WITH_CUDA', None)]
extra_objects += ['torch_cluster/build/kernel.so']
extra_objects += ['torch_cluster/build/{}.so'.format(f) for f in files]
with_cuda = True
ffi = create_extension(
......
......@@ -7,4 +7,6 @@ SRC_DIR=torch_cluster/kernel
BUILD_DIR=torch_cluster/build
mkdir -p $BUILD_DIR
$(which nvcc) -c -o "$BUILD_DIR/kernel.so" "$SRC_DIR/kernel.cu" -arch=sm_35 -Xcompiler -fPIC -shared "-I$TORCH/lib/include/TH" "-I$TORCH/lib/include/THC" "-I$SRC_DIR"
for i in serial grid; do
$(which nvcc) -c -o "$BUILD_DIR/$i.so" "$SRC_DIR/$i.cu" -arch=sm_35 -Xcompiler -fPIC -shared "-I$TORCH/lib/include/TH" "-I$TORCH/lib/include/THC" "-I$SRC_DIR"
done
......@@ -5,4 +5,4 @@ description-file = README.md
test = pytest
[tool:pytest]
addopts = --capture=no --cov
addopts = --capture=no
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#define THC_GENERIC_FILE "generic/grid.cu"
#else
void cluster_(grid)(THCState *state, int C, THCudaLongTensor *output, THCTensor *position, THCTensor *size, THCudaLongTensor *count) {
......@@ -28,3 +28,4 @@ void cluster_(grid)(THCState *state, int C, THCudaLongTensor *output, THCTensor
}
#endif
#include <THC.h>
#include "kernel.h"
#include "grid.h"
#include "common.cuh"
#include "THCIndex.cuh"
......@@ -24,17 +24,17 @@ __global__ void gridKernel(int64_t *output, TensorInfo<Real> position, Real *siz
}
}
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateDoubleType.h"
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateByteType.h"
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateCharType.h"
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateShortType.h"
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateIntType.h"
#include "generic/kernel.cu"
#include "generic/grid.cu"
#include "THCGenerateLongType.h"
#include <THC/THC.h>
#include "kernel.h"
#include "grid.h"
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _cuda_, Real)
#define cluster_kernel_(NAME) TH_CONCAT_4(cluster_, NAME, _kernel_, Real)
......
#include <THC/THC.h>
#include "kernel.h"
#include "serial.h"
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _cuda_, Real)
#define cluster_kernel_(NAME) TH_CONCAT_4(cluster_, NAME, _kernel_, Real)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment