OpenDAS / torch-scatter / Commits

Commit 3bed6293, authored Dec 20, 2017 by rusty1s
Parent: b88d5613

    more atomic operations, please
Showing 3 changed files with 78 additions and 36 deletions:

  torch_scatter/kernel/THCAtomics.cuh      +45  -0
  torch_scatter/kernel/generic/kernel.cu    +5  -8
  torch_scatter/kernel/kernel.h            +28  -28
torch_scatter/kernel/THCAtomics.cuh @ 3bed6293
...
@@ -98,6 +98,41 @@ struct AtomicDecimalImpl<T, 8> {
   }
 };
 
+static inline __device__ void atomicAdd( uint8_t *address, uint8_t val) { AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomicAdd( int8_t *address, int8_t val) { AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val); }
+static inline __device__ void atomicAdd( int16_t *address, int16_t val) { AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomicAdd( int64_t *address, int64_t val) { AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
+static inline __device__ void atomicAdd( double *address, double val) { AtomicDecimalImpl<double, sizeof(double)>()(address, val); }
+#elif !defined(__CUDA_ARCH__) && (CUDA_VERSION < 8000)
+static inline __device__ void atomicAdd( double *address, double val) {}
+#endif
+#ifdef CUDA_HALF_TENSOR
+static inline __device__ void atomicAdd( half *address, half val) {}
+#endif
+
+static inline __device__ void atomicMul( uint8_t *address, uint8_t val) { AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomicMul( int8_t *address, int8_t val) { AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val); }
+static inline __device__ void atomicMul( int16_t *address, int16_t val) { AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomicMul( int32_t *address, int32_t val) { AtomicIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
+static inline __device__ void atomicMul( int64_t *address, int64_t val) { AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomicMul( float *address, float val) { AtomicDecimalImpl<float, sizeof(float)>()(address, val); }
+static inline __device__ void atomicMul( double *address, double val) { AtomicDecimalImpl<double, sizeof(double)>()(address, val); }
+#ifdef CUDA_HALF_TENSOR
+static inline __device__ void atomicMul( half *address, half val) {}
+#endif
+
+static inline __device__ void atomicDiv( uint8_t *address, uint8_t val) { AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomicDiv( int8_t *address, int8_t val) { AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val); }
+static inline __device__ void atomicDiv( int16_t *address, int16_t val) { AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomicDiv( int32_t *address, int32_t val) { AtomicIntegerImpl<int32_t, sizeof(int32_t)>()(address, val); }
+static inline __device__ void atomicDiv( int64_t *address, int64_t val) { AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomicDiv( float *address, float val) { AtomicDecimalImpl<float, sizeof(float)>()(address, val); }
+static inline __device__ void atomicDiv( double *address, double val) { AtomicDecimalImpl<double, sizeof(double)>()(address, val); }
+#ifdef CUDA_HALF_TENSOR
+static inline __device__ void atomicDiv( half *address, half val) {}
+#endif
+
 static inline __device__ void atomicMax( uint8_t *address, uint8_t val) { AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
 static inline __device__ void atomicMax( int8_t *address, int8_t val) { AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val); }
 static inline __device__ void atomicMax( int16_t *address, int16_t val) { AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
@@ -107,3 +142,13 @@ static inline __device__ void atomicMax( double *address, double val) { AtomicD
 #ifdef CUDA_HALF_TENSOR
 static inline __device__ void atomicMax( half *address, half val) {}
 #endif
+
+static inline __device__ void atomicMin( uint8_t *address, uint8_t val) { AtomicIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val); }
+static inline __device__ void atomicMin( int8_t *address, int8_t val) { AtomicIntegerImpl<int8_t, sizeof(int8_t)>()(address, val); }
+static inline __device__ void atomicMin( int16_t *address, int16_t val) { AtomicIntegerImpl<int16_t, sizeof(int16_t)>()(address, val); }
+static inline __device__ void atomicMin( int64_t *address, int64_t val) { AtomicIntegerImpl<int64_t, sizeof(int64_t)>()(address, val); }
+static inline __device__ void atomicMin( float *address, float val) { AtomicDecimalImpl<float, sizeof(float)>()(address, val); }
+static inline __device__ void atomicMin( double *address, double val) { AtomicDecimalImpl<double, sizeof(double)>()(address, val); }
+#ifdef CUDA_HALF_TENSOR
+static inline __device__ void atomicMin( half *address, half val) {}
+#endif
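CUDA has no native atomicMul, atomicDiv, or sub-word atomicAdd, which is why the overloads above route through the AtomicIntegerImpl and AtomicDecimalImpl functors whose definitions sit above this hunk and are not part of the diff. The standard way to build such operations is a compare-and-swap retry loop. The following is a minimal sketch of that pattern for a float multiply; atomicMulSketch is a hypothetical name, and nothing here is claimed about the repo's actual functor bodies.

// Sketch of the atomicCAS retry loop such Impl functors are typically built on.
// The update is recomputed and retried until no other thread has modified the
// 32-bit word in between the read and the swap.
static inline __device__ void atomicMulSketch(float *address, float val) {
  int *address_as_int = (int *)address;  // reinterpret the word for atomicCAS
  int old = *address_as_int, assumed;
  do {
    assumed = old;                                   // snapshot current bits
    float updated = __int_as_float(assumed) * val;   // apply the operation
    old = atomicCAS(address_as_int, assumed, __float_as_int(updated));
  } while (assumed != old);                          // retry on interference
}

The same loop works for any binary operation and, with some shifting and masking against the enclosing 32-bit word, for the 8- and 16-bit integer types as well.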
torch_scatter/kernel/generic/kernel.cu @ 3bed6293
...
@@ -12,33 +12,30 @@ void scatter_(div)(THCState *state, int dim, THCTensor *output, THCudaLongTensor
   printf("div");
 }
 
-void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *num_output) {
+void scatter_(mean)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCTensor *count) {
   thc_(check)(state, output, index, input);
   printf("mean");
 }
 
-void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
+void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
   thc_(check)(state, output, index, input);
   const int n = THCudaLongTensor_nElement(state, index);
   TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
   TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
   TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
-  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg_output);
+  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
   KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
   /* KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, dim) */
-  /* maxKernel<real, -1><<<GET_BLOCKS(n), NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(outputInfo, indexInfo, inputInfo, dim, n); */
-  /* argKernel<real, -1><<<GET_BLOCKS(n), NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(dim, n); */
 }
 
-void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg_output) {
+void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
   thc_(check)(state, output, index, input);
   printf("min");
 }
 
-void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg_grad) {
+void index_backward(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *grad, THCudaLongTensor *arg) {
   thc_(check)(state, output, index, grad);
   printf("index_backward");
 }
...
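Of the functions in this file, only the max path actually launches a kernel at this point; mean, div, min, and index_backward are still printf placeholders. maxKernel itself is defined elsewhere in the repo and is not part of this diff. The sketch below only illustrates, under assumed names and a flattened 1-D layout, how a scatter-max kernel leans on the atomicMax overloads from THCAtomics.cuh; the real kernel walks strided TensorInfo layouts.

// Hypothetical 1-D sketch of the atomic usage pattern, not the repo's maxKernel.
template <typename scalar_t>
__global__ void scatterMaxSketch(scalar_t *output, const int64_t *index,
                                 const scalar_t *input, int n) {
  // One thread per input element; grid-stride looping omitted for brevity.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Concurrent updates to the same output slot are serialized by the
    // atomicMax overloads this commit's header provides for each type.
    atomicMax(output + index[i], input[i]);
  }
}

Recovering the argmax positions for the arg tensor then needs a second pass that compares each input[i] against the settled output value, which is presumably the job of the commented-out argKernel launch above.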
torch_scatter/kernel/kernel.h @ 3bed6293
...
@@ -18,37 +18,37 @@ void scatter_div_kernel_Short (THCState *state, int dim, THCudaShortTensor *out
 void scatter_div_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
 void scatter_div_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);
 
-void scatter_mean_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *num_output);
+void scatter_mean_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *count);
-void scatter_mean_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *num_output);
+void scatter_mean_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *count);
-void scatter_mean_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *num_output);
+void scatter_mean_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *count);
-void scatter_mean_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *num_output);
+void scatter_mean_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *count);
-void scatter_mean_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *num_output);
+void scatter_mean_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *count);
-void scatter_mean_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *num_output);
+void scatter_mean_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *count);
-void scatter_mean_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *num_output);
+void scatter_mean_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *count);
 
-void scatter_max_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
-void scatter_max_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
-void scatter_max_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
-void scatter_max_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
-void scatter_max_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
-void scatter_max_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
-void scatter_max_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg_output);
+void scatter_max_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);
 
-void scatter_min_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
-void scatter_min_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
-void scatter_min_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
-void scatter_min_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
-void scatter_min_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
-void scatter_min_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
-void scatter_min_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg_output);
+void scatter_min_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);
 
-void index_backward_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Float (THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *grad, THCudaLongTensor *arg);
-void index_backward_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Double (THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *grad, THCudaLongTensor *arg);
-void index_backward_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Byte (THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *grad, THCudaLongTensor *arg);
-void index_backward_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Char (THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *grad, THCudaLongTensor *arg);
-void index_backward_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Short (THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *grad, THCudaLongTensor *arg);
-void index_backward_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Int (THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *grad, THCudaLongTensor *arg);
-void index_backward_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *grad, THCudaLongTensor *arg_grad);
+void index_backward_kernel_Long (THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *grad, THCudaLongTensor *arg);
 
 #ifdef __cplusplus
 }
...
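Every declaration in this header exists once per scalar type because generic/kernel.cu is compiled repeatedly through THC's generic-include machinery: scatter_(NAME) token-pastes the type suffix on each pass, which is why renaming one parameter in the generic source touches all seven per-type declarations here. Below is a self-contained demonstration of that expansion trick; names like CONCAT_4 and the single Float pass are illustrative, not the repo's actual build glue.

// Compiles as plain C (or CUDA host code): shows how scatter_(max) becomes
// the ordinary symbol scatter_max_kernel_Float via two-level token pasting.
#include <stdio.h>

#define CONCAT_4_EXPAND(a, b, c, d) a##b##c##d
#define CONCAT_4(a, b, c, d) CONCAT_4_EXPAND(a, b, c, d)

// In THC, "Real" is redefined (Float, Double, ...) before each re-include of
// the generic source, so the same definition stamps out one symbol per type.
#define Real Float
#define scatter_(NAME) CONCAT_4(scatter_, NAME, _kernel_, Real)

void scatter_(max)(int dim) { printf("scatter_max_kernel_Float(dim=%d)\n", dim); }

int main(void) {
  scatter_max_kernel_Float(0);  // the expanded name is a normal C symbol
  return 0;
}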