"tools/vscode:/vscode.git/clone" did not exist on "dff4b320c4c5efe0984cfa90d3b02029f81809ff"
Commit e55d73ef authored by rusty1s's avatar rusty1s
Browse files

cleanup

parent a487a44a
include LICENSE
include cpu/scatter.cpp
include cpu/dim_apply.cpp
recursive-include cpu *
recursive-include cuda *
import os.path as osp
import subprocess
import torch
from torch.utils.ffi import create_extension
headers = ['torch_scatter/src/cpu.h']
sources = ['torch_scatter/src/cpu.c']
include_dirs = ['torch_scatter/src']
define_macros = []
extra_objects = []
extra_compile_args = ['-std=c99']
with_cuda = False
if torch.cuda.is_available():
subprocess.call(['./build.sh', osp.dirname(torch.__file__)])
headers += ['torch_scatter/src/gpu.h']
sources += ['torch_scatter/src/gpu.c']
include_dirs += ['torch_scatter/kernel']
define_macros += [('WITH_CUDA', None)]
extra_objects += ['torch_scatter/build/kernel.so']
with_cuda = True
ffi = create_extension(
name='torch_scatter._ext.ffi',
package=True,
headers=headers,
sources=sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_objects=extra_objects,
extra_compile_args=extra_compile_args,
with_cuda=with_cuda,
relative_to=__file__)
if __name__ == '__main__':
ffi.build()
#!/bin/sh
echo "Compiling kernel..."
if [ -z "$1" ]; then TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))"); else TORCH="$1"; fi
SRC_DIR=torch_scatter/kernel
BUILD_DIR=torch_scatter/build
mkdir -p $BUILD_DIR
$(which nvcc) "-I$TORCH/lib/include" "-I$TORCH/lib/include/TH" "-I$TORCH/lib/include/THC" "-I$SRC_DIR" -c "$SRC_DIR/kernel.cu" -o "$BUILD_DIR/kernel.so" --compiler-options '-fPIC' -std=c++11
import glob
from setuptools import setup, find_packages
import torch.cuda
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension
ext_modules = [
......@@ -14,7 +12,7 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if torch.cuda.is_available():
ext_modules += [
CUDAExtension('scatter_cuda',
['cuda/scatter.cpp'] + glob.glob('cuda/*.cu'))
['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])
]
__version__ = '1.0.4'
......
#define TH_TENSOR_DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, TENSOR4, DIMENSION, CODE) { \
TYPE1 *TENSOR1##_data = NULL; \
int64_t TENSOR1##_stride = 0, TENSOR1##_size = 0; \
TYPE2 *TENSOR2##_data = NULL; \
int64_t TENSOR2##_stride = 0, TENSOR2##_size = 0; \
TYPE3 *TENSOR3##_data = NULL; \
int64_t TENSOR3##_stride = 0, TENSOR3##_size = 0; \
TYPE4 *TENSOR4##_data = NULL; \
int64_t TENSOR4##_stride = 0, TENSOR4##_size = 0; \
\
int64_t *TH_TENSOR_DIM_APPLY_counter = NULL; \
int TH_TENSOR_DIM_APPLY_hasFinished = 0; \
int TH_TENSOR_DIM_APPLY_i; \
\
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
\
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
} \
\
TENSOR1##_data = (TENSOR1)->storage->data+(TENSOR1)->storageOffset; \
TENSOR1##_stride = (TENSOR1)->stride[DIMENSION]; \
TENSOR1##_size = TENSOR1->size[DIMENSION]; \
\
TENSOR2##_data = (TENSOR2)->storage->data+(TENSOR2)->storageOffset; \
TENSOR2##_stride = (TENSOR2)->stride[DIMENSION]; \
TENSOR2##_size = TENSOR2->size[DIMENSION]; \
\
TENSOR3##_data = (TENSOR3)->storage->data+(TENSOR3)->storageOffset; \
TENSOR3##_stride = (TENSOR3)->stride[DIMENSION]; \
TENSOR3##_size = TENSOR3->size[DIMENSION]; \
\
TENSOR4##_data = (TENSOR4)->storage->data+(TENSOR4)->storageOffset; \
TENSOR4##_stride = (TENSOR4)->stride[DIMENSION]; \
TENSOR4##_size = TENSOR4->size[DIMENSION]; \
\
while (!TH_TENSOR_DIM_APPLY_hasFinished) { \
CODE \
\
if (TENSOR1->nDimension == 1) break; \
\
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) { \
if (TH_TENSOR_DIM_APPLY_i == DIMENSION) { \
if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
break; \
} \
continue; \
} \
\
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]++; \
TENSOR1##_data += TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR2##_data += TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR3##_data += TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR4##_data += TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
\
if (TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] == TENSOR1->size[TH_TENSOR_DIM_APPLY_i]) { \
if (TH_TENSOR_DIM_APPLY_i == TENSOR1->nDimension-1) { \
TH_TENSOR_DIM_APPLY_hasFinished = 1; \
break; \
} \
else { \
TENSOR1##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR1->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR2##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR2->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR3##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR3->stride[TH_TENSOR_DIM_APPLY_i]; \
TENSOR4##_data -= TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i]*TENSOR4->stride[TH_TENSOR_DIM_APPLY_i]; \
TH_TENSOR_DIM_APPLY_counter[TH_TENSOR_DIM_APPLY_i] = 0; \
} \
} \
else break; \
} \
} \
THFree(TH_TENSOR_DIM_APPLY_counter); \
}
#include <TH/TH.h>
#include "THTensorDimApply4.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
#define index_backward TH_CONCAT_2(index_backward_, Real)
inline void assertIndexInBoundaries(int idx, int size, int64_t *free) {
if (idx < 0 || idx >= size) { THFree(free); THError("Invalid index"); }
}
#include "generic/cpu.c"
#include "THGenerateAllTypes.h"
void scatter_mul_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_mul_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_mul_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_mul_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_mul_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_mul_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_mul_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_div_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input);
void scatter_div_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input);
void scatter_div_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input);
void scatter_div_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input);
void scatter_div_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input);
void scatter_div_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input);
void scatter_div_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input);
void scatter_mean_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THFloatTensor *count);
void scatter_mean_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THDoubleTensor *count);
void scatter_mean_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THByteTensor *count);
void scatter_mean_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THCharTensor *count);
void scatter_mean_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THShortTensor *count);
void scatter_mean_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THIntTensor *count);
void scatter_mean_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *count);
void scatter_max_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *arg);
void scatter_max_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *arg);
void scatter_max_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *arg);
void scatter_max_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *arg);
void scatter_max_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *arg);
void scatter_max_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *arg);
void scatter_max_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *arg);
void scatter_min_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *input, THLongTensor *arg);
void scatter_min_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *input, THLongTensor *arg);
void scatter_min_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *input, THLongTensor *arg);
void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *input, THLongTensor *arg);
void scatter_min_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *input, THLongTensor *arg);
void scatter_min_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *input, THLongTensor *arg);
void scatter_min_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *input, THLongTensor *arg);
void index_backward_Float (int dim, THFloatTensor *output, THLongTensor *index, THFloatTensor *grad, THLongTensor *arg);
void index_backward_Double(int dim, THDoubleTensor *output, THLongTensor *index, THDoubleTensor *grad, THLongTensor *arg);
void index_backward_Byte (int dim, THByteTensor *output, THLongTensor *index, THByteTensor *grad, THLongTensor *arg);
void index_backward_Char (int dim, THCharTensor *output, THLongTensor *index, THCharTensor *grad, THLongTensor *arg);
void index_backward_Short (int dim, THShortTensor *output, THLongTensor *index, THShortTensor *grad, THLongTensor *arg);
void index_backward_Int (int dim, THIntTensor *output, THLongTensor *index, THIntTensor *grad, THLongTensor *arg);
void index_backward_Long (int dim, THLongTensor *output, THLongTensor *index, THLongTensor *grad, THLongTensor *arg);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/cpu.c"
#else
void scatter_(mul)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] *= *(input_data + i * input_stride);
})
}
void scatter_(div)(int dim, THTensor *output, THLongTensor *index, THTensor *input) {
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY3(real, output, int64_t, index, real, input, dim, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] /= *(input_data + i * input_stride);
})
}
void scatter_(mean)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THTensor *count) {
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, real, count, dim,
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
output_data[idx * output_stride] += *(input_data + i * input_stride);
count_data[idx * count_stride]++;
})
}
void scatter_(max)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
if (*(input_data + i * input_stride) >= *(output_data + idx * output_stride)) {
output_data[idx * output_stride] = *(input_data + i * input_stride);
arg_data[idx * arg_stride] = i;
}
})
}
void scatter_(min)(int dim, THTensor *output, THLongTensor *index, THTensor *input, THLongTensor *arg) {
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, input, int64_t, arg, dim,
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
assertIndexInBoundaries(idx, output_size, TH_TENSOR_DIM_APPLY_counter);
if (*(input_data + i * input_stride) <= *(output_data + idx * output_stride)) {
output_data[idx * output_stride] = *(input_data + i * input_stride);
arg_data[idx * arg_stride] = i;
}
})
}
void index_backward(int dim, THTensor *output, THLongTensor *index, THTensor *grad, THLongTensor *arg) {
int64_t n, i, idx;
n = THLongTensor_size(index, dim);
TH_TENSOR_DIM_APPLY4(real, output, int64_t, index, real, grad, int64_t, arg, dim,
for (i = 0; i < n; i++) {
idx = *(index_data + i * index_stride);
if (*(arg_data + idx * arg_stride) == i) output_data[i * output_stride] = *(grad_data + idx * grad_stride);
})
}
#endif
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/cuda.c"
#else
void scatter_(mul)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
scatter_kernel_(mul)(state, dim, output, index, input);
}
void scatter_(div)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input) {
scatter_kernel_(div)(state, dim, output, index, input);
}
void scatter_(mean)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *count) {
scatter_kernel_(mean)(state, dim, output, index, input, count);
}
void scatter_(max)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
scatter_kernel_(max)(state, dim, output, index, input, arg);
}
void scatter_(min)(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
scatter_kernel_(min)(state, dim, output, index, input, arg);
}
void index_backward(int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
index_backward_kernel(state, dim, output, index, grad, arg);
}
#endif
#include <THC/THC.h>
#include "kernel.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _cuda_, Real)
#define scatter_kernel_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_cuda_, Real)
#define index_backward_kernel TH_CONCAT_2(index_backward_kernel_, Real)
extern THCState *state;
#include "generic/cuda.c"
#include "THCGenerateFloatType.h"
#include "generic/cuda.c"
#include "THCGenerateDoubleType.h"
#include "generic/cuda.c"
#include "THCGenerateByteType.h"
#include "generic/cuda.c"
#include "THCGenerateCharType.h"
#include "generic/cuda.c"
#include "THCGenerateShortType.h"
#include "generic/cuda.c"
#include "THCGenerateIntType.h"
#include "generic/cuda.c"
#include "THCGenerateLongType.h"
void scatter_mul_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_mul_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_mul_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_mul_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_mul_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_mul_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_mul_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_div_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_div_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_div_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_div_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_div_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_div_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_div_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
void scatter_mean_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *count);
void scatter_mean_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *count);
void scatter_mean_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *count);
void scatter_mean_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *count);
void scatter_mean_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *count);
void scatter_mean_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *count);
void scatter_mean_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *count);
void scatter_max_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
void scatter_max_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
void scatter_max_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
void scatter_max_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
void scatter_max_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
void scatter_max_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
void scatter_max_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
void scatter_min_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);
void index_backward_cuda_Float (int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *grad, THCudaLongTensor *arg);
void index_backward_cuda_Double(int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *grad, THCudaLongTensor *arg);
void index_backward_cuda_Byte (int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *grad, THCudaLongTensor *arg);
void index_backward_cuda_Char (int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *grad, THCudaLongTensor *arg);
void index_backward_cuda_Short (int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *grad, THCudaLongTensor *arg);
void index_backward_cuda_Int (int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *grad, THCudaLongTensor *arg);
void index_backward_cuda_Long (int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *grad, THCudaLongTensor *arg);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment