Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
3e06f342
Commit
3e06f342
authored
Dec 20, 2017
by
rusty1s
Browse files
argmax impl
parent
bdf2563a
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
16 deletions
+52
-16
test/test_max.py
test/test_max.py
+1
-0
torch_scatter/kernel/common.cuh
torch_scatter/kernel/common.cuh
+20
-10
torch_scatter/kernel/generic/kernel.cu
torch_scatter/kernel/generic/kernel.cu
+11
-3
torch_scatter/kernel/kernel.cu
torch_scatter/kernel/kernel.cu
+20
-3
No files found.
test/test_max.py
View file @
3e06f342
...
...
@@ -51,3 +51,4 @@ def test_scatter_cuda_max(str):
_
,
arg_output
=
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
print
(
output
)
print
(
arg_output
)
torch_scatter/kernel/common.cuh
View file @
3e06f342
...
...
@@ -25,13 +25,23 @@ struct TensorInfo {
for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < N; i += blockDim.x * gridDim.x)
/* #define KERNEL_RUN(NAME, DIMS, N, PARAMS) \ */
/*
 * KERNEL_RUN(NAME, DIMS, N, ...)
 *
 * Launches the kernel template NAME<real, D> with GET_BLOCKS(N) blocks of
 * NUM_THREADS threads on the stream returned by
 * THCState_getCurrentStream(state), forwarding the variadic arguments
 * followed by N as the last kernel argument.
 *
 * DIMS selects the compile-time dimensionality: 1, 2 and 3 get dedicated
 * instantiations, every other value uses the -1 instantiation.
 *
 * Requirements at the call site: a variable named `state` must be in scope.
 *
 * The expansion is wrapped in its own brace scope so that the `grid` and
 * `stream` locals do not leak into the caller; this allows the macro to be
 * invoked more than once inside the same function. The expansion is a
 * complete compound statement, so no trailing semicolon is needed.
 *
 * Launch-configuration errors are reported through
 * THCudaCheck(cudaGetLastError()) right after the launch.
 */
#define KERNEL_RUN(NAME, DIMS, N, ...) {                                                \
  int grid = GET_BLOCKS(N);                                                             \
  cudaStream_t stream = THCState_getCurrentStream(state);                               \
  switch (DIMS) {                                                                       \
    case  1: NAME<real,  1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break;   \
    case  2: NAME<real,  2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break;   \
    case  3: NAME<real,  3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break;   \
    default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break;   \
  }                                                                                     \
  THCudaCheck(cudaGetLastError());                                                      \
}
// Device-side equality predicate, uint8_t overload: true when a and b hold the same value.
static inline __device__ bool eq(uint8_t a, uint8_t b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, int8_t overload: true when a and b hold the same value.
static inline __device__ bool eq(int8_t a, int8_t b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, int16_t overload: true when a and b hold the same value.
static inline __device__ bool eq(int16_t a, int16_t b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, int32_t overload: true when a and b hold the same value.
static inline __device__ bool eq(int32_t a, int32_t b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, int64_t overload: true when a and b hold the same value.
static inline __device__ bool eq(int64_t a, int64_t b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, float overload.
// Uses the built-in == operator, i.e. an exact comparison with no tolerance.
static inline __device__ bool eq(float a, float b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, double overload.
// Uses the built-in == operator, i.e. an exact comparison with no tolerance.
static inline __device__ bool eq(double a, double b) {
  const bool same = (a == b);
  return same;
}
// Device-side equality predicate, half overload.
// Both operands are converted with __half2float and the resulting floats
// are compared with ==.
static inline __device__ bool eq(half a, half b) {
  const float lhs = __half2float(a);
  const float rhs = __half2float(b);
  return lhs == rhs;
}
torch_scatter/kernel/generic/kernel.cu
View file @
3e06f342
...
...
@@ -26,13 +26,21 @@ void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
int64_t
>
argInfo
=
thc_getTensorInfo_Long
(
state
,
arg
);
KERNEL_RUN
(
maxKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
argInfo
,
dim
)
/*
KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, dim)
*/
KERNEL_RUN
(
maxKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
dim
)
KERNEL_RUN
(
argKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
argInfo
,
dim
)
}
/*
 * scatter_(min): host entry point for the scatter-min operation.
 *
 * Parameters:
 *   state  - THC state; also used by KERNEL_RUN to obtain the launch stream.
 *   dim    - dimension forwarded to the kernels.
 *   output - tensor that minKernel updates with atomicMin.
 *   index  - long tensor; its element count is the number of work items.
 *   input  - source tensor.
 *   arg    - long tensor handed to argKernel.
 *
 * Steps:
 *   1. thc_(check) is run on output, index and input (arg is not passed to it).
 *   2. TensorInfo descriptors are built for all four tensors.
 *   3. minKernel is launched, followed by argKernel; both go through
 *      KERNEL_RUN and therefore onto the same stream, selected by
 *      indexInfo.dims.
 *
 * The debug printf that used to run on every call has been removed so that
 * the function no longer writes to stdout.
 */
void scatter_(min)(THCState *state, int dim, THCTensor *output, THCudaLongTensor *index, THCTensor *input, THCudaLongTensor *arg) {
  thc_(check)(state, output, index, input);

  // One work item per element of index.
  const int n = THCudaLongTensor_nElement(state, index);

  TensorInfo<real> outputInfo = thc_(getTensorInfo)(state, output);
  TensorInfo<int64_t> indexInfo = thc_getTensorInfo_Long(state, index);
  TensorInfo<real> inputInfo = thc_(getTensorInfo)(state, input);
  TensorInfo<int64_t> argInfo = thc_getTensorInfo_Long(state, arg);

  // KERNEL_RUN expands to a complete statement; no trailing semicolon needed.
  KERNEL_RUN(minKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, dim)
  KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, inputInfo, argInfo, dim)
}
void
index_backward
(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
grad
,
THCudaLongTensor
*
arg
)
{
...
...
torch_scatter/kernel/kernel.cu
View file @
3e06f342
...
...
@@ -14,12 +14,29 @@
#include "THCGenerateAllTypes.h"
/*
 * maxKernel: for every work item i in [0, n), resolves the offsets of the
 * matching elements in index, input and output through
 * IndexToScatterOffsets3<Real, Real, Dims>::compute and then applies
 * atomicMax to the output element with the input element as operand.
 *
 * Template parameters:
 *   Real - element type of output and input.
 *   Dims - compile-time dimensionality forwarded to IndexToScatterOffsets3
 *          (KERNEL_RUN instantiates it with 1, 2, 3 or -1).
 *
 * Parameters:
 *   output - destination tensor descriptor, updated atomically.
 *   index  - int64 index tensor descriptor.
 *   input  - source tensor descriptor.
 *   dim    - dimension forwarded to the offset computation.
 *   n      - number of work items.
 *
 * The signature deliberately carries no `arg` tensor: it matches minKernel
 * and the five-argument KERNEL_RUN(maxKernel, ...) launch; arg indices are
 * produced by argKernel.
 *
 * NOTE(review): the loop variable must stay named `i`; the loop header
 * visible in common.cuh increments a literal `i` rather than the macro
 * parameter -- confirm against the KERNEL_LOOP definition.
 */
template<typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    IndexToScatterOffsets3<Real, Real, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset);
    // Several work items may target the same output element, hence the atomic.
    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
  }
}
/*
 * minKernel: scatter-min counterpart of maxKernel.
 *
 * For every work item i in [0, n) the offsets into index, input and output
 * are obtained from IndexToScatterOffsets3<Real, Real, Dims>::compute, and
 * the output element is then updated with atomicMin using the input element.
 *
 * Template parameters:
 *   Real - element type of output and input.
 *   Dims - compile-time dimensionality forwarded to IndexToScatterOffsets3.
 *
 * Parameters:
 *   output - destination tensor descriptor, updated atomically.
 *   index  - int64 index tensor descriptor.
 *   input  - source tensor descriptor.
 *   dim    - dimension forwarded to the offset computation.
 *   n      - number of work items.
 *
 * NOTE(review): the loop variable must stay named `i`; the loop header
 * visible in common.cuh increments a literal `i` -- confirm against the
 * KERNEL_LOOP definition.
 */
template<typename Real, int Dims>
__global__ void minKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int dstOffset = 0;
    int idxOffset = 0;
    int srcOffset = 0;

    IndexToScatterOffsets3<Real, Real, Dims>::compute(
        i, dim,
        index, &idxOffset,
        input, &srcOffset,
        output, &dstOffset);

    // Atomic because multiple work items can map to the same output element.
    atomicMin(&output.data[dstOffset], input.data[srcOffset]);
  }
}
/*
 * argKernel: records, for each output element, which input position
 * produced the value currently stored there.
 *
 * For every work item i in [0, n) the offsets into index, input, output and
 * arg are obtained from
 * IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute. When the input
 * element compares equal (via eq) to the output element, the arg element is
 * set to inputOffset % input.size[dim].
 *
 * This kernel only reads `output`; it must be launched after the reduction
 * kernel (maxKernel or minKernel) on the same stream so that `output` already
 * holds the final reduced values. It intentionally performs no reduction of
 * its own: an atomicMax here would be redundant after maxKernel and would
 * overwrite the result of minKernel.
 *
 * Template parameters:
 *   Real - element type of output and input.
 *   Dims - compile-time dimensionality forwarded to IndexToScatterOffsets4.
 *
 * Parameters:
 *   output - reduced tensor descriptor (read only here).
 *   index  - int64 index tensor descriptor.
 *   input  - source tensor descriptor.
 *   arg    - int64 tensor descriptor receiving the winning positions.
 *   dim    - dimension forwarded to the offset computation and used to pick
 *            input.size[dim].
 *   n      - number of work items.
 *
 * If several input elements equal the output value, more than one work item
 * writes the same arg element and which write lands last is unspecified.
 *
 * NOTE(review): inputOffset % input.size[dim] presumably yields the
 * coordinate along `dim` only when `dim` is the innermost, unit-stride
 * dimension of `input` -- confirm, and account for the stride otherwise.
 *
 * NOTE(review): the loop variable must stay named `i`; the loop header
 * visible in common.cuh increments a literal `i` -- confirm against the
 * KERNEL_LOOP definition.
 */
template<typename Real, int Dims>
__global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, TensorInfo<Real> input, TensorInfo<int64_t> arg, const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    int argOffset = 0;
    IndexToScatterOffsets4<Real, Real, int64_t, Dims>::compute(i, dim, index, &indexOffset, input, &inputOffset, output, &outputOffset, arg, &argOffset);
    // Only the element(s) matching the reduced value record their position.
    if (eq(input.data[inputOffset], output.data[outputOffset])) {
      arg.data[argOffset] = inputOffset % input.size[dim];
    }
  }
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment