Commit 55581cd6 authored by rusty1s's avatar rusty1s
Browse files

rename, outsource cpu methods

parent 645f3ddf
...@@ -8,8 +8,10 @@ from torch.utils.ffi import create_extension ...@@ -8,8 +8,10 @@ from torch.utils.ffi import create_extension
if osp.exists('build'): if osp.exists('build'):
shutil.rmtree('build') shutil.rmtree('build')
headers = ['torch_cluster/src/cpu.h'] files = ['serial', 'grid']
sources = ['torch_cluster/src/cpu.c']
headers = ['torch_cluster/src/{}_cpu.h'.format(f) for f in files]
sources = ['torch_cluster/src/{}_cpu.c'.format(f) for f in files]
include_dirs = ['torch_cluster/src'] include_dirs = ['torch_cluster/src']
define_macros = [] define_macros = []
extra_objects = [] extra_objects = []
......
from .functions.grid import sparse_grid_cluster, dense_grid_cluster
from .functions.serial import serial_cluster from .functions.serial import serial_cluster
from .functions.grid import sparse_grid_cluster, dense_grid_cluster
__version__ = '0.2.6' __version__ = '0.2.6'
__all__ = [ __all__ = [
'sparse_grid_cluster', 'dense_grid_cluster', 'serial_cluster', 'serial_cluster', 'sparse_grid_cluster', 'dense_grid_cluster',
'__version__' '__version__'
] ]
void cluster_grid_Float (int C, THLongTensor *output, THFloatTensor *position, THFloatTensor *size, THLongTensor *count);
void cluster_grid_Double(int C, THLongTensor *output, THDoubleTensor *position, THDoubleTensor *size, THLongTensor *count);
void cluster_grid_Byte (int C, THLongTensor *output, THByteTensor *position, THByteTensor *size, THLongTensor *count);
void cluster_grid_Char (int C, THLongTensor *output, THCharTensor *position, THCharTensor *size, THLongTensor *count);
void cluster_grid_Short (int C, THLongTensor *output, THShortTensor *position, THShortTensor *size, THLongTensor *count);
void cluster_grid_Int (int C, THLongTensor *output, THIntTensor *position, THIntTensor *size, THLongTensor *count);
void cluster_grid_Long (int C, THLongTensor *output, THLongTensor *position, THLongTensor *size, THLongTensor *count);
void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree);
#ifndef TH_GENERIC_FILE #ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/cpu.c" #define TH_GENERIC_FILE "generic/grid_cpu.c"
#else #else
void cluster_(grid)(int C, THLongTensor *output, THTensor *position, THTensor *size, THLongTensor *count) { void grid_cluster(int C, THLongTensor *output, THTensor *position, THTensor *size, THLongTensor *count) {
real *size_data = size->storage->data + size->storageOffset; real *size_data = size->storage->data + size->storageOffset;
int64_t *count_data = count->storage->data + count->storageOffset; int64_t *count_data = count->storage->data + count->storageOffset;
int64_t D, d, i, c, tmp; int64_t D, d, i, c, tmp;
......
#include <TH/TH.h>
#define grid_cluster TH_CONCAT_2(grid_cluster_, Real)
#include "generic/grid_cpu.c"
#include "THGenerateAllTypes.h"
void grid_cluster_Float (int C, THLongTensor *output, THFloatTensor *position, THFloatTensor *size, THLongTensor *count);
void grid_cluster_Double(int C, THLongTensor *output, THDoubleTensor *position, THDoubleTensor *size, THLongTensor *count);
void grid_cluster_Byte (int C, THLongTensor *output, THByteTensor *position, THByteTensor *size, THLongTensor *count);
void grid_cluster_Char (int C, THLongTensor *output, THCharTensor *position, THCharTensor *size, THLongTensor *count);
void grid_cluster_Short (int C, THLongTensor *output, THShortTensor *position, THShortTensor *size, THLongTensor *count);
void grid_cluster_Int (int C, THLongTensor *output, THIntTensor *position, THIntTensor *size, THLongTensor *count);
void grid_cluster_Long (int C, THLongTensor *output, THLongTensor *position, THLongTensor *size, THLongTensor *count);
#include <TH/TH.h> #include <TH/TH.h>
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real) void serial_cluster(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
int64_t *output_data = output->storage->data + output->storageOffset; int64_t *output_data = output->storage->data + output->storageOffset;
int64_t *row_data = row->storage->data + row->storageOffset; int64_t *row_data = row->storage->data + row->storageOffset;
int64_t *col_data = col->storage->data + col->storageOffset; int64_t *col_data = col->storage->data + col->storageOffset;
...@@ -40,5 +38,4 @@ void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col, ...@@ -40,5 +38,4 @@ void cluster_random(THLongTensor *output, THLongTensor *row, THLongTensor *col,
} }
} }
#include "generic/cpu.c"
#include "THGenerateAllTypes.h"
void serial_cluster(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree);
void serial_cluster_Float (THLongTensor *output, THLongTensor *row, THLongTensor *col, THFloatTensor *weight, THLongTensor *degree);
void serial_cluster_Double(THLongTensor *output, THLongTensor *row, THLongTensor *col, THDoubleTensor *weight, THLongTensor *degree);
void serial_cluster_Byte (THLongTensor *output, THLongTensor *row, THLongTensor *col, THByteTensor *weight, THLongTensor *degree);
void serial_cluster_Char (THLongTensor *output, THLongTensor *row, THLongTensor *col, THCharTensor *weight, THLongTensor *degree);
void serial_cluster_Short (THLongTensor *output, THLongTensor *row, THLongTensor *col, THShortTensor *weight, THLongTensor *degree);
void serial_cluster_Int (THLongTensor *output, THLongTensor *row, THLongTensor *col, THIntTensor *weight, THLongTensor *degree);
void serial_cluster_Long (THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *weight, THLongTensor *degree);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment