OpenDAS / torch-scatter · Commits

Commit 38c8b3ac — "cuda kernel"
Authored Aug 06, 2018 by rusty1s
Parent: a2f18da3

Showing 20 changed files with 605 additions and 547 deletions (+605 −547)
cuda/atomics.cuh                          +228    −0
cuda/index.cuh                            +108    −0
cuda/scatter.cpp                           +65    −0
cuda/scatter_kernel.cu                    +189    −0
test/test_backward.py                       +1    −2
test/utils.py                               +3    −0
torch_scatter/add.py                        +1    −1
torch_scatter/div.py                        +2    −2
torch_scatter/kernel/THCAtomics.cuh         +0  −157
torch_scatter/kernel/THCIndex.cuh           +0   −89
torch_scatter/kernel/common.cuh             +0   −37
torch_scatter/kernel/generic/common.cu      +0   −22
torch_scatter/kernel/generic/kernel.cu      +0   −77
torch_scatter/kernel/kernel.cu              +0   −97
torch_scatter/kernel/kernel.h               +0   −55
torch_scatter/max.py                        +2    −2
torch_scatter/mean.py                       +1    −1
torch_scatter/min.py                        +2    −2
torch_scatter/mul.py                        +2    −2
torch_scatter/sub.py                        +1    −1

cuda/atomics.cuh · new file (mode 100644)

#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t)address & 3) * 8; \
uint32_t sum; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, scalar((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (sum << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = \
(uint32_t *)((char *)address - ((size_t)address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t sum; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
: scalar(old & 0xffff)); \
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
: (old & 0xffff0000) | sum; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long *address_as_ull = (unsigned long long *)address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, \
__float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long int *address_as_ull = \
(unsigned long long int *)address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS( \
address_as_ull, assumed, \
__double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
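The macro above emulates sub-word atomics by widening to an aligned 32-bit CAS: it locates the target byte (or half-word) inside its 32-bit cell, applies OP to that lane, and swaps the whole cell back only if it has not changed in the meantime. As a side note, the following host-only sketch (not part of the commit; the file name byte_lane_demo.cu and the helper apply_add_to_byte are illustrative) shows just the lane-packing arithmetic used by the one-byte specialization, without the atomicCAS retry loop:

// byte_lane_demo.cu -- host-only illustration of the <scalar, 1> packing used
// by Atomic##NAME##IntegerImpl: operate on one byte inside its 32-bit cell.
// (Illustrative sketch, not part of this commit.)
#include <cstdint>
#include <cstdio>

uint32_t apply_add_to_byte(uint32_t cell, size_t byte_in_cell, uint8_t val) {
  uint32_t shift = byte_in_cell * 8;            // lane offset in bits
  uint8_t cur = (cell >> shift) & 0xff;         // current byte value
  uint8_t sum = (uint8_t)(cur + val);           // OP(val, cur) for the Add case
  return (cell & ~(0x000000ffu << shift)) | ((uint32_t)sum << shift);
}

int main() {
  uint32_t cell = 0x04030201;            // four packed uint8 counters
  cell = apply_add_to_byte(cell, 2, 5);  // add 5 to byte #2 (value 0x03)
  printf("0x%08x\n", cell);              // expected: 0x04080201
  return 0;
}

The device code performs the same update inside a do/while loop around atomicCAS, so a concurrent modification of any byte in the cell simply causes a retry.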
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
  atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
  AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
  AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif

#define OP(X, Y) Y *X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
  AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
  AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
  AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
  AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
  AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
  AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
  AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
  AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
  AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
  AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
  AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
  AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
  AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
  AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
  AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
  AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
  AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
  atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
  AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
  AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
  AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
  AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
  AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
  AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
  atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
  AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
  AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
  AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
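As a quick sanity check of the new header in isolation, here is a minimal, self-contained sketch (not part of the commit) that drives the 64-bit integer atomMax overload from a plain reduction kernel. The file name demo_atomics.cu, the managed-memory allocation, and the launch configuration are illustrative assumptions; only atomics.cuh itself comes from this change.

// demo_atomics.cu -- hypothetical standalone test; assumes atomics.cuh is on
// the include path. Build with: nvcc demo_atomics.cu -o demo_atomics
#include <cstdint>
#include <cstdio>
#include "atomics.cuh"

// Every thread folds one input element into a single output slot; atomMax
// resolves the write races via the CAS loop in AtomicMaxIntegerImpl<int64_t, 8>.
__global__ void max_reduce(const int64_t *in, int64_t *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    atomMax(out, in[i]);
}

int main() {
  const int n = 1 << 20;
  int64_t *in, *out;
  cudaMallocManaged(&in, n * sizeof(int64_t));
  cudaMallocManaged(&out, sizeof(int64_t));
  for (int i = 0; i < n; i++)
    in[i] = i % 12345;
  *out = INT64_MIN;
  max_reduce<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("max = %lld\n", (long long)*out);  // expected: 12344
  cudaFree(in);
  cudaFree(out);
  return 0;
}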

cuda/index.cuh · new file (mode 100644)

#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
template <typename scalar1, typename scalar2, int64_t Dims>
struct IndexToScatterOffsets3 {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset) {
    for (int64_t d = Dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
  }
};

template <typename scalar1, typename scalar2>
struct IndexToScatterOffsets3<scalar1, scalar2, -1> {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset) {
    for (int64_t d = index.dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
  }
};

template <typename scalar1, typename scalar2, typename scalar3, int64_t Dims>
struct IndexToScatterOffsets4 {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset,
          const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
          int64_t *t3Offset) {
    for (int64_t d = Dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
        *t3Offset += curDimIndex * t3.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
    *t3Offset += indexValue * t3.strides[dim];
  }
};

template <typename scalar1, typename scalar2, typename scalar3>
struct IndexToScatterOffsets4<scalar1, scalar2, scalar3, -1> {
  static __device__ void
  compute(int64_t i, const int64_t dim,
          const at::cuda::detail::TensorInfo<int64_t, int64_t> &index,
          int64_t *indexOffset,
          const at::cuda::detail::TensorInfo<scalar1, int64_t> &t1,
          int64_t *t1Offset,
          const at::cuda::detail::TensorInfo<scalar2, int64_t> &t2,
          int64_t *t2Offset,
          const at::cuda::detail::TensorInfo<scalar3, int64_t> &t3,
          int64_t *t3Offset) {
    for (int64_t d = index.dims - 1; d >= 0; d--) {
      int64_t curDimIndex = i % index.sizes[d];
      *indexOffset += curDimIndex * index.strides[d];
      *t1Offset += curDimIndex * t1.strides[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.strides[d];
        *t3Offset += curDimIndex * t3.strides[d];
      }
      i /= index.sizes[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    *t2Offset += indexValue * t2.strides[dim];
    *t3Offset += indexValue * t3.strides[dim];
  }
};
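To make the offset arithmetic easier to follow, here is a small host-side sketch (not part of the commit) that mirrors compute() on plain arrays: the element id i is decomposed into per-dimension coordinates from the last dimension backwards, each coordinate is multiplied into the strided offsets of index and src, and the output offset is finally redirected along dim through the value read from index. The Info struct and the 2×2 example data are illustrative assumptions standing in for at::cuda::detail::TensorInfo.

// offsets_demo.cu -- host-only sketch of the IndexToScatterOffsets3 logic.
// (Illustrative, not part of this commit.)
#include <cstdint>
#include <cstdio>
#include <vector>

struct Info {  // stand-in for at::cuda::detail::TensorInfo
  std::vector<int64_t> sizes, strides;
  const int64_t *data;
};

void compute(int64_t i, int64_t dim, const Info &index, int64_t *indexOffset,
             const Info &t1, int64_t *t1Offset, const Info &t2,
             int64_t *t2Offset) {
  for (int64_t d = (int64_t)index.sizes.size() - 1; d >= 0; d--) {
    int64_t cur = i % index.sizes[d];      // coordinate of element i in dim d
    *indexOffset += cur * index.strides[d];
    *t1Offset += cur * t1.strides[d];
    if (d != dim)                          // dim is resolved via index below
      *t2Offset += cur * t2.strides[d];
    i /= index.sizes[d];
  }
  *t2Offset += index.data[*indexOffset] * t2.strides[dim];
}

int main() {
  // index = [[0, 1], [1, 0]], scattering along dim = 1 into a 2x2 output.
  const int64_t idx[] = {0, 1, 1, 0};
  Info index{{2, 2}, {2, 1}, idx};
  Info src{{2, 2}, {2, 1}};
  Info out{{2, 2}, {2, 1}};
  for (int64_t i = 0; i < 4; i++) {
    int64_t io = 0, so = 0, oo = 0;
    compute(i, 1, index, &io, src, &so, out, &oo);
    printf("element %lld: src offset %lld -> out offset %lld\n", (long long)i,
           (long long)so, (long long)oo);
  }
  return 0;
}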

cuda/scatter.cpp · new file (mode 100644)

#include <torch/torch.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim);
void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim);
void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim);
void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim);
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                         at::Tensor out, int64_t dim);

void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  scatter_mul_cuda(src, index, out, dim);
}

void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  scatter_div_cuda(src, index, out, dim);
}

void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
                 at::Tensor arg, int64_t dim) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  CHECK_CUDA(arg);
  scatter_max_cuda(src, index, out, arg, dim);
}

void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
                 at::Tensor arg, int64_t dim) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  CHECK_CUDA(arg);
  scatter_min_cuda(src, index, out, arg, dim);
}

void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                    at::Tensor out, int64_t dim) {
  CHECK_CUDA(grad);
  CHECK_CUDA(index);
  CHECK_CUDA(arg);
  CHECK_CUDA(out);
  index_backward_cuda(grad, index, arg, out, dim);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)");
  m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)");
  m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)");
  m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)");
  m.def("index_backward", &index_backward, "Index Backward (CUDA)");
}

cuda/scatter_kernel.cu · new file (mode 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "atomics.cuh"
#include "index.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \
switch (DIMS) { \
case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
break; \
case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
break; \
case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
break; \
default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N); \
} \
}()
template <typename scalar_t, int64_t Dims>
__global__ void
scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMul(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul_kernel", [&] {
    KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomDiv(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div_kernel", [&] {
    KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}

template <typename scalar_t, int64_t Dims>
__global__ void
arg_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
           at::cuda::detail::TensorInfo<int64_t, int64_t> index,
           at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
           at::cuda::detail::TensorInfo<int64_t, int64_t> arg, int64_t dim,
           size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0, argOffset = 0;
    IndexToScatterOffsets4<scalar_t, scalar_t, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset, arg,
        &argOffset);
    if (src.data[srcOffset] == out.data[outOffset]) {
      arg.data[argOffset] = (srcOffset / src.strides[dim]) % src.sizes[dim];
    }
  }
}

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMax(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max_kernel", [&] {
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_max_kernel, index.dim(), index.numel(), src_info,
               index_info, out_info, dim);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
               out_info,
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg), dim);
  });
}

template <typename scalar_t, int64_t Dims>
__global__ void
scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
                   at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                   at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                   int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t srcOffset = 0, indexOffset = 0, outOffset = 0;
    IndexToScatterOffsets3<scalar_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, src, &srcOffset, out, &outOffset);
    atomMin(&out.data[outOffset], src.data[srcOffset]);
  }
}

void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                      at::Tensor arg, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min_kernel", [&] {
    auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
    auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
    auto out_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out);
    KERNEL_RUN(scatter_min_kernel, index.dim(), index.numel(), src_info,
               index_info, out_info, dim);
    KERNEL_RUN(arg_kernel, index.dim(), index.numel(), src_info, index_info,
               out_info,
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg), dim);
  });
}

template <typename scalar_t, int64_t Dims>
__global__ void
index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                      at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                      int64_t dim, size_t numel) {
  const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = idx; i < numel; i += stride) {
    int64_t gradOffset = 0, indexOffset = 0, argOffset = 0, outOffset = 0;
    IndexToScatterOffsets4<scalar_t, int64_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, out, &outOffset, arg, &argOffset, grad,
        &gradOffset);
    if (arg.data[argOffset] ==
        (outOffset / out.strides[dim]) % out.sizes[dim]) {
      out.data[outOffset] = grad.data[gradOffset];
    }
  }
}

void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                         at::Tensor out, int64_t dim) {
  AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward_kernel", [&] {
    KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}
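For reference while reading the kernels above, this is a tiny host-only sketch (not part of the commit) of what scatter_mul computes in the 1-D case: every src element is folded into out[index[i]], which is exactly why concurrent GPU threads writing the same output slot need atomMul instead of a plain store. The example values are illustrative.

// scatter_mul_reference.cu -- host-only reference of the 1-D semantics.
// (Illustrative sketch, not part of this commit.)
#include <cstdint>
#include <cstdio>

int main() {
  const float src[5] = {2, 3, 4, 5, 6};
  const int64_t index[5] = {0, 1, 0, 1, 2};
  float out[3] = {1, 1, 1};  // multiplicative identity as fill value
  for (int i = 0; i < 5; i++)
    out[index[i]] *= src[i];  // GPU version: atomMul(&out.data[outOffset], ...)
  printf("%g %g %g\n", out[0], out[1], out[2]);  // expected: 8 15 6
  return 0;
}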

test/test_backward.py · modified

...
@@ -5,9 +5,8 @@ import torch
 from torch.autograd import gradcheck

 import torch_scatter
-from .utils import devices, tensor
+from .utils import grad_dtypes as dtypes, devices, tensor

-dtypes = [torch.float, torch.double]
 funcs = ['add', 'sub', 'mul', 'div', 'mean']
 indices = [2, 0, 1, 1, 0]
...

test/utils.py · modified

...
@@ -3,6 +3,9 @@ from torch.testing import get_all_dtypes
 dtypes = get_all_dtypes()
 dtypes.remove(torch.half)
+dtypes.remove(torch.short)  # PyTorch scatter does not work on short types.
+
+grad_dtypes = [torch.float, torch.double]

 devices = [torch.device('cpu')]
 if torch.cuda.is_available():  # pragma: no cover
...

torch_scatter/add.py · modified

-from .utils.gen import gen
+from torch_scatter.utils.gen import gen

 def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
...

torch_scatter/div.py · modified

 from torch.autograd import Function

-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen

 class ScatterDiv(Function):
...

torch_scatter/kernel/THCAtomics.cuh · deleted (was mode 100644)

#define ATOMIC_(NAME) \
template <typename T, size_t n> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl); \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 1> { \
inline __device__ void operator()(T *address, T val) { \
uint32_t *address_as_ui = (uint32_t *) (address - ((size_t) address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t) address & 3) * 8; \
uint32_t res; \
uint32_t assumed; \
\
do { \
assumed = old; \
res = OP(val, T((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (res << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 2> { \
inline __device__ void operator()(T *address, T val) { \
uint32_t *address_as_ui = (uint32_t *) ((char *) address - ((size_t) address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t res; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
res = OP(val, (size_t) address & 2 ? T(old >> 16) : T(old & 0xffff)); \
newval = (size_t) address & 2 ? (old & 0xffff) | (res << 16) : (old & 0xffff0000) | res; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 4> { \
inline __device__ void operator()(T *address, T val) { \
uint32_t *address_as_ui = (uint32_t *) address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (T) old)); \
} while (assumed != old); \
} \
}; \
\
template<typename T> \
struct TH_CONCAT_3(Atomic, NAME, IntegerImpl)<T, 8> { \
inline __device__ void operator()(T *address, T val) { \
unsigned long long *address_as_ull = (unsigned long long *) address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (T) old)); \
} while (assumed != old); \
} \
}; \
\
template <typename T, size_t n> \
struct TH_CONCAT_3(Atomic, NAME, DecimalImpl); \
\
template <typename T> \
struct TH_CONCAT_3(Atomic, NAME, DecimalImpl)<T, 4> { \
inline __device__ void operator()(T *address, T val) { \
int *address_as_i = (int *) address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, __float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename T> \
struct TH_CONCAT_3(Atomic, NAME, DecimalImpl)<T, 8> { \
inline __device__ void operator()(T *address, T val) { \
unsigned long long int *address_as_ull = (unsigned long long int *) address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
#define OP(X, Y) Y + X
ATOMIC_(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
  AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
  AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
  AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
  atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
  AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
  AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif

#define OP(X, Y) Y * X
ATOMIC_(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
  AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
  AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
  AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
  AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
  AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
  AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
  AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) Y / X
ATOMIC_(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
  AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
  AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
  AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
  AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
  AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
  AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
  AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) max(Y, X)
ATOMIC_(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
  AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
  AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
  AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
  atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
  AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
  AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
  AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}

#define OP(X, Y) min(Y, X)
ATOMIC_(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
  AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
  AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
  AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
  atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
  AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
  AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
  AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}

torch_scatter/kernel/THCIndex.cuh · deleted (was mode 100644)

template <typename a, typename b, int Dims>
struct IndexToScatterOffsets3 {
  static __device__ void compute(
      int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
      const TensorInfo<b>& t2, int* t2Offset) {
    int curDimIndex;
    for (int d = Dims - 1; d >= 0; d--) {
      curDimIndex = i % index.size[d];
      *indexOffset += curDimIndex * index.stride[d];
      *t1Offset += curDimIndex * t1.stride[d];
      if (d != dim) *t2Offset += curDimIndex * t2.stride[d];
      i /= index.size[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
  }
};

template <typename a, typename b>
struct IndexToScatterOffsets3<a, b, -1> {
  static __device__ void compute(
      int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
      const TensorInfo<b>& t2, int* t2Offset) {
    int curDimIndex;
    for (int d = index.dims - 1; d >= 0; d--) {
      curDimIndex = i % index.size[d];
      *indexOffset += curDimIndex * index.stride[d];
      *t1Offset += curDimIndex * t1.stride[d];
      if (d != dim) *t2Offset += curDimIndex * t2.stride[d];
      i /= index.size[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
  }
};

template <typename a, typename b, typename c, int Dims>
struct IndexToScatterOffsets4 {
  static __device__ void compute(
      int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
      const TensorInfo<b>& t2, int* t2Offset,
      const TensorInfo<c>& t3, int* t3Offset) {
    int curDimIndex;
    for (int d = Dims - 1; d >= 0; d--) {
      curDimIndex = i % index.size[d];
      *indexOffset += curDimIndex * index.stride[d];
      *t1Offset += curDimIndex * t1.stride[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.stride[d];
        *t3Offset += curDimIndex * t3.stride[d];
      }
      i /= index.size[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
    *t3Offset += indexValue * t3.stride[dim];
  }
};

template <typename a, typename b, typename c>
struct IndexToScatterOffsets4<a, b, c, -1> {
  static __device__ void compute(
      int i, const int dim,
      const TensorInfo<int64_t>& index, int* indexOffset,
      const TensorInfo<a>& t1, int* t1Offset,
      const TensorInfo<b>& t2, int* t2Offset,
      const TensorInfo<c>& t3, int* t3Offset) {
    int curDimIndex;
    for (int d = index.dims - 1; d >= 0; d--) {
      curDimIndex = i % index.size[d];
      *indexOffset += curDimIndex * index.stride[d];
      *t1Offset += curDimIndex * t1.stride[d];
      if (d != dim) {
        *t2Offset += curDimIndex * t2.stride[d];
        *t3Offset += curDimIndex * t3.stride[d];
      }
      i /= index.size[d];
    }
    int64_t indexValue = index.data[*indexOffset];
    assert(indexValue >= 0 && indexValue < t2.size[dim]);
    *t2Offset += indexValue * t2.stride[dim];
    *t3Offset += indexValue * t3.stride[dim];
  }
};

torch_scatter/kernel/common.cuh · deleted (was mode 100644)

const int MAX_DIMS = 25;
const int NUM_THREADS = 1024;

inline int GET_BLOCKS(const int n) {
  return (n + NUM_THREADS - 1) / NUM_THREADS;
}

template <typename T>
struct TensorInfo {
  TensorInfo(T *t, int d, int sz[MAX_DIMS], int st[MAX_DIMS]) {
    data = t;
    dims = d;
    for (int i = 0; i < dims; i++) {
      size[i] = sz[i];
      stride[i] = st[i];
    }
  }

  T *data;
  int dims;
  int size[MAX_DIMS];
  int stride[MAX_DIMS];
};
#define KERNEL_LOOP(I, N) \
for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < N; i += blockDim.x * gridDim.x)
#define KERNEL_RUN(NAME, DIMS, N, ...) { \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
switch (DIMS) { \
case 1: NAME<real, 1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 2: NAME<real, 2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 3: NAME<real, 3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
} \
THCudaCheck(cudaGetLastError()); \
}

torch_scatter/kernel/generic/common.cu · deleted (was mode 100644)

#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/common.cu"
#else
void thc_(check)(THCState *state, THCTensor *output, THCudaLongTensor *index,
                 THCTensor *input) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, output, input));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, index));
  THArgCheck(THCTensor_(nDimension)(state, output) <= MAX_DIMS, 1,
             "Tensor too large or too many dimensions");
}

TensorInfo<real> thc_(getTensorInfo)(THCState *state, THCTensor *tensor) {
  real *data = THCTensor_(data)(state, tensor);
  int dims = THCTensor_(nDimension)(state, tensor);
  int size[MAX_DIMS];
  int stride[MAX_DIMS];
  for (int i = 0; i < dims; i++) {
    size[i] = THCTensor_(size)(state, tensor, i);
    stride[i] = THCTensor_(stride)(state, tensor, i);
  }
  return TensorInfo<real>(data, dims, size, stride);
}
#endif

torch_scatter/kernel/generic/kernel.cu · deleted (was mode 100644)

#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/kernel.cu"
#else
void scatter_(mul)(THCState *state, int dim, THCTensor *output,
                   THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);
  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  KERNEL_RUN(mulKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
}

void scatter_(div)(THCState *state, int dim, THCTensor *output,
                   THCudaLongTensor *index, THCTensor *input) {
  thc_(check)(state, output, index, input);
  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  KERNEL_RUN(divKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
}

void scatter_(mean)(THCState *state, int dim, THCTensor *output,
                    THCudaLongTensor *index, THCTensor *input,
                    THCTensor *count) {
  thc_(check)(state, output, index, input);
  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<real> countInfo = thc_(getTensorInfo)(state, count);
  KERNEL_RUN(meanKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo,
             countInfo, dim)
}

void scatter_(max)(THCState *state, int dim, THCTensor *output,
                   THCudaLongTensor *index, THCTensor *input,
                   THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);
  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
  KERNEL_RUN(maxKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo,
             argInfo, dim)
}

void scatter_(min)(THCState *state, int dim, THCTensor *output,
                   THCudaLongTensor *index, THCTensor *input,
                   THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);
  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
  KERNEL_RUN(minKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo,
             argInfo, dim)
}

void index_backward(THCState *state, int dim, THCTensor *output,
                    THCudaLongTensor *index, THCTensor *grad,
                    THCudaLongTensor *arg) {
  thc_(check)(state, output, index, grad);
  const int n = THCudaLongTensor_nElement(state, index);
  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> gradInfo = thc_(getTensorInfo)(state, grad);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);
  KERNEL_RUN(indexBackwardKernel, indexInfo.dims, n, outputInfo, indexInfo,
             gradInfo, argInfo, dim)
}
#endif

torch_scatter/kernel/kernel.cu · deleted (was mode 100644)

#include <THC.h>
#include "kernel.h"
#include "common.cuh"
#include "THCIndex.cuh"
#include "THCAtomics.cuh"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _kernel_, Real)
#define index_backward TH_CONCAT_2(index_backward_kernel_, Real)
#define thc_(NAME) TH_CONCAT_4(thc_, NAME, _, Real)
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"
template <typename Real, int Dims>
__global__ void mulKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
    atomMul(&output.data[outputOffset], input.data[inputOffset]);
  }
}

template <typename Real, int Dims>
__global__ void divKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
    atomDiv(&output.data[outputOffset], input.data[inputOffset]);
  }
}

template <typename Real, int Dims>
__global__ void meanKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                           TensorInfo<Real> input, TensorInfo<Real> count,
                           const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    int countOffset = 0;
    IndexToScatterOffsets4<Real, Real, Real, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset,
        count, &countOffset);
    atomAdd(&output.data[outputOffset], input.data[inputOffset]);
    atomAdd(&count.data[countOffset], 1);
  }
}

template <typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
    atomMax(&output.data[outputOffset], input.data[inputOffset]);
  }
}

template <typename Real, int Dims>
__global__ void minKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
    atomMin(&output.data[outputOffset], input.data[inputOffset]);
  }
}

template <typename Real, int Dims>
__global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, TensorInfo<int64_t> arg,
                          const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    int argOffset = 0;
    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset,
        arg, &argOffset);
    if (input.data[inputOffset] == output.data[outputOffset]) {
      arg.data[argOffset] = (inputOffset / input.stride[dim]) % input.size[dim];
    }
  }
}

template <typename Real, int Dims>
__global__ void indexBackwardKernel(TensorInfo<Real> output,
                                    TensorInfo<int64_t> index,
                                    TensorInfo<Real> grad,
                                    TensorInfo<int64_t> arg, const int dim,
                                    const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int gradOffset = 0;
    int argOffset = 0;
    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(
        i, dim, index, &indexOffset, output, &outputOffset, grad, &gradOffset,
        arg, &argOffset);
    if (arg.data[argOffset] ==
        (outputOffset / output.stride[dim]) % output.size[dim]) {
      output.data[outputOffset] = grad.data[gradOffset];
    }
  }
}
#include "generic/kernel.cu"
#include "THCGenerateFloatType.h"
#include "generic/kernel.cu"
#include "THCGenerateDoubleType.h"
#include "generic/kernel.cu"
#include "THCGenerateByteType.h"
#include "generic/kernel.cu"
#include "THCGenerateCharType.h"
#include "generic/kernel.cu"
#include "THCGenerateShortType.h"
#include "generic/kernel.cu"
#include "THCGenerateIntType.h"
#include "generic/kernel.cu"
#include "THCGenerateLongType.h"

torch_scatter/kernel/kernel.h · deleted (was mode 100644)

#ifdef __cplusplus
extern "C" {
#endif

void scatter_mul_kernel_Float(THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_mul_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_mul_kernel_Byte(THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_mul_kernel_Char(THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_mul_kernel_Short(THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_mul_kernel_Int(THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_mul_kernel_Long(THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);

void scatter_div_kernel_Float(THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input);
void scatter_div_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input);
void scatter_div_kernel_Byte(THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input);
void scatter_div_kernel_Char(THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input);
void scatter_div_kernel_Short(THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input);
void scatter_div_kernel_Int(THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input);
void scatter_div_kernel_Long(THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input);

void scatter_mean_kernel_Float(THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaTensor *count);
void scatter_mean_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaDoubleTensor *count);
void scatter_mean_kernel_Byte(THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaByteTensor *count);
void scatter_mean_kernel_Char(THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaCharTensor *count);
void scatter_mean_kernel_Short(THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaShortTensor *count);
void scatter_mean_kernel_Int(THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaIntTensor *count);
void scatter_mean_kernel_Long(THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *count);

void scatter_max_kernel_Float(THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Byte(THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Char(THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Short(THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Int(THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
void scatter_max_kernel_Long(THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);

void scatter_min_kernel_Float(THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Byte(THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Char(THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Short(THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Int(THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *input, THCudaLongTensor *arg);
void scatter_min_kernel_Long(THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *input, THCudaLongTensor *arg);

void index_backward_kernel_Float(THCState *state, int dim, THCudaTensor *output, THCudaLongTensor *index, THCudaTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Double(THCState *state, int dim, THCudaDoubleTensor *output, THCudaLongTensor *index, THCudaDoubleTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Byte(THCState *state, int dim, THCudaByteTensor *output, THCudaLongTensor *index, THCudaByteTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Char(THCState *state, int dim, THCudaCharTensor *output, THCudaLongTensor *index, THCudaCharTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Short(THCState *state, int dim, THCudaShortTensor *output, THCudaLongTensor *index, THCudaShortTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Int(THCState *state, int dim, THCudaIntTensor *output, THCudaLongTensor *index, THCudaIntTensor *grad, THCudaLongTensor *arg);
void index_backward_kernel_Long(THCState *state, int dim, THCudaLongTensor *output, THCudaLongTensor *index, THCudaLongTensor *grad, THCudaLongTensor *arg);

#ifdef __cplusplus
}
#endif

torch_scatter/max.py · modified

 from torch.autograd import Function

-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen

 class ScatterMax(Function):
...

torch_scatter/mean.py · modified

 import torch

-from .add import scatter_add
+from torch_scatter import scatter_add

 def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
...

torch_scatter/min.py · modified

 from torch.autograd import Function

-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen

 class ScatterMin(Function):
...

torch_scatter/mul.py · modified

 from torch.autograd import Function

-from .utils.ext import get_func
-from .utils.gen import gen
+from torch_scatter.utils.ext import get_func
+from torch_scatter.utils.gen import gen

 class ScatterMul(Function):
...

torch_scatter/sub.py · modified

-from .add import scatter_add
+from torch_scatter import scatter_add

 def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
...