Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
6372815e
Commit
6372815e
authored
Dec 20, 2017
by
rusty1s
Browse files
first atomic max impl
parent
b3091036
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
148 additions
and
9 deletions
+148
-9
test/test_max.py
test/test_max.py
+1
-2
torch_scatter/kernel/THCAtomics.cuh
torch_scatter/kernel/THCAtomics.cuh
+106
-0
torch_scatter/kernel/THCTensorInfo.cuh
torch_scatter/kernel/THCTensorInfo.cuh
+0
-1
torch_scatter/kernel/common.cuh
torch_scatter/kernel/common.cuh
+15
-3
torch_scatter/kernel/generic/kernel.cu
torch_scatter/kernel/generic/kernel.cu
+6
-2
torch_scatter/kernel/kernel.cu
torch_scatter/kernel/kernel.cu
+20
-1
No files found.
test/test_max.py
View file @
6372815e
...
@@ -37,8 +37,7 @@ def test_scatter_max(str):
...
@@ -37,8 +37,7 @@ def test_scatter_max(str):
assert
input
.
grad
.
data
.
tolist
()
==
expected_grad_input
assert
input
.
grad
.
data
.
tolist
()
==
expected_grad_input
# @pytest.mark.parametrize('str', tensor_strs)
@
pytest
.
mark
.
parametrize
(
'str'
,
tensor_strs
)
@
pytest
.
mark
.
parametrize
(
'str'
,
[
'FloatTensor'
])
def
test_scatter_cuda_max
(
str
):
def
test_scatter_cuda_max
(
str
):
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
...
...
torch_scatter/kernel/THCAtomics.cuh
0 → 100644
View file @
6372815e
// Size-dispatched emulation of atomicMax for integer types that have no
// native CUDA atomicMax (1-, 2-, and 8-byte integers), implemented with a
// compare-and-swap loop on an aligned machine word.
template <typename T, size_t n>
struct AtomicMaxIntegerImpl;

// 1-byte specialization: CAS on the aligned 32-bit word containing *address.
template <typename T>
struct AtomicMaxIntegerImpl<T, 1> {
  inline __device__ void operator()(T *address, T val) {
    // Round down to the enclosing 4-byte word (sizeof(T) == 1, so plain
    // pointer arithmetic is byte arithmetic here).
    uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3));
    uint32_t old = *address_as_ui;
    // Bit offset of the target byte within the word (little-endian).
    uint32_t shift = (((size_t)address & 3) * 8);
    uint32_t assumed;
    do {
      assumed = old;
      // BUG FIX: this is atomicMax, so take the max of val and the stored
      // byte -- the original computed val + byte (copied from atomicAdd).
      // Round-trip through T so signed comparison works for int8_t.
      uint32_t maxval = (uint32_t)(uint8_t)max(val, T((old >> shift) & 0xff));
      old = (old & ~(0x000000ff << shift)) | (maxval << shift);
      old = atomicCAS(address_as_ui, assumed, old);
    } while (assumed != old);
  }
};
// 2-byte specialization: CAS on the aligned 32-bit word containing *address.
template <typename T>
struct AtomicMaxIntegerImpl<T, 2> {
  inline __device__ void operator()(T *address, T val) {
    // (char *) cast so the alignment adjustment is in bytes.
    uint32_t *address_as_ui =
        (uint32_t *)((char *)address - ((size_t)address & 2));
    uint32_t old = *address_as_ui;
    uint32_t newval;
    uint32_t assumed;
    // Which half-word of the 32-bit slot holds the target value.
    const bool high_half = ((size_t)address & 2) != 0;
    do {
      assumed = old;
      // BUG FIX 1: take the max, not the sum (copied from atomicAdd).
      // BUG FIX 2: the original wrote `val + (size_t)address & 2 ? ... : ...`,
      // which parses as `(val + address) & 2` -- parenthesize the selector.
      T current = high_half ? T(old >> 16) : T(old & 0xffff);
      uint32_t maxval = (uint32_t)(uint16_t)max(val, current);
      newval = high_half ? (old & 0xffff) | (maxval << 16)
                         : (old & 0xffff0000) | maxval;
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};
// 4-byte specialization: direct CAS on the value itself.
template <typename T>
struct AtomicMaxIntegerImpl<T, 4> {
  inline __device__ void operator()(T *address, T val) {
    uint32_t *address_as_ui = (uint32_t *)(address);
    uint32_t old = *address_as_ui;
    uint32_t newval;
    uint32_t assumed;
    do {
      assumed = old;
      // BUG FIX: max, not sum -- the original `val + (T)old` was copied
      // from an atomicAdd implementation.
      newval = (uint32_t)max(val, (T)old);
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};
// 8-byte specialization: direct CAS via the 64-bit atomicCAS overload.
template <typename T>
struct AtomicMaxIntegerImpl<T, 8> {
  inline __device__ void operator()(T *address, T val) {
    unsigned long long *address_as_ui = (unsigned long long *)(address);
    unsigned long long old = *address_as_ui;
    unsigned long long newval;
    unsigned long long assumed;
    do {
      assumed = old;
      // BUG FIX: max, not sum -- the original `val + (T)old` was copied
      // from an atomicAdd implementation.
      newval = (unsigned long long)max(val, (T)old);
      old = atomicCAS(address_as_ui, assumed, newval);
    } while (assumed != old);
  }
};
// atomicMax for uint8_t (no native CUDA overload).
// BUG FIX: the stub body was empty, silently dropping the update; dispatch
// to the 1-byte CAS-loop emulation defined above.
static inline __device__ void atomicMax(uint8_t *address, uint8_t val) {
  AtomicMaxIntegerImpl<uint8_t, 1>()(address, val);
}
// atomicMax for int8_t (no native CUDA overload).
// BUG FIX: the stub body was empty, silently dropping the update; dispatch
// to the 1-byte CAS-loop emulation defined above.
static inline __device__ void atomicMax(int8_t *address, int8_t val) {
  AtomicMaxIntegerImpl<int8_t, 1>()(address, val);
}
// atomicMax for int16_t (no native CUDA overload).
// BUG FIX: the stub body was empty, silently dropping the update; dispatch
// to the 2-byte CAS-loop emulation defined above.
static inline __device__ void atomicMax(int16_t *address, int16_t val) {
  AtomicMaxIntegerImpl<int16_t, 2>()(address, val);
}
// atomicMax for int64_t. The native long long atomicMax needs SM 3.5+, so
// use the portable 8-byte CAS-loop emulation defined above.
// BUG FIX: the stub body was empty, silently dropping the update.
static inline __device__ void atomicMax(int64_t *address, int64_t val) {
  AtomicMaxIntegerImpl<int64_t, 8>()(address, val);
}
#ifdef CUDA_HALF_TENSOR
// NOTE(review): unimplemented stub -- atomicMax on half is silently a
// no-op, so scatter-max on HalfTensor produces wrong results. TODO:
// emulate via a CAS loop on the containing 32-bit word, as done for the
// 1- and 2-byte integer cases above.
static inline __device__ void atomicMax(half *address, half val) {}
#endif
// Atomic max for float (CUDA has no native float atomicMax): repeatedly
// compare-and-swap the value's bit pattern until no other thread raced us.
static inline __device__ void atomicMax(float *address, float val) {
  int *slot = (int *)address;
  int seen = *slot;
  int expected;
  do {
    expected = seen;
    // Candidate: max of the incoming value and the value we last observed.
    float best = max(val, __int_as_float(expected));
    seen = atomicCAS(slot, expected, __float_as_int(best));
  } while (expected != seen);
}
// Atomic max for double (no native CUDA overload): compare-and-swap the
// 64-bit pattern until the slot is stable.
static inline __device__ void atomicMax(double *address, double val) {
  unsigned long long int *slot = (unsigned long long int *)address;
  unsigned long long int seen = *slot;
  unsigned long long int expected;
  do {
    expected = seen;
    // Candidate: max of the incoming value and the value we last observed.
    double best = max(val, __longlong_as_double(expected));
    seen = atomicCAS(slot, expected, __double_as_longlong(best));
  } while (expected != seen);
}
torch_scatter/kernel/THCTensorInfo.cuh
deleted
100644 → 0
View file @
b3091036
torch_scatter/kernel/common.cuh
View file @
6372815e
#define KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
const
int
MAX_DIMS
=
25
;
const
int
MAX_DIMS
=
25
;
const
int
NUM_THREADS
=
1024
;
const
int
NUM_THREADS
=
1024
;
...
@@ -23,3 +20,18 @@ struct TensorInfo {
...
@@ -23,3 +20,18 @@ struct TensorInfo {
int
size
[
MAX_DIMS
];
int
size
[
MAX_DIMS
];
int
stride
[
MAX_DIMS
];
int
stride
[
MAX_DIMS
];
};
};
// Grid-stride loop: iterates I over [0, N) across all launched threads.
// BUG FIX: the increment used lowercase `i` instead of the macro parameter
// `I`, so the macro only compiled when invoked as KERNEL_LOOP(i, ...).
#define KERNEL_LOOP(I, N) \
  for (int I = blockIdx.x * blockDim.x + threadIdx.x; I < N; I += blockDim.x * gridDim.x)
/* #define KERNEL_RUN(NAME, DIMS, N, PARAMS) \ */
// Launches NAME<real, DIMS> on GET_BLOCKS(N) blocks of NUM_THREADS on the
// current THC stream, appending N to the caller-supplied kernel arguments.
// DIMS in {1, 2, 3} picks a dimension-specialized instantiation; any other
// value falls back to the generic <real, -1> version.
// NOTE(review): the expansion declares `grid` and `stream` and is not
// wrapped in do { } while (0), so it can be used at most once per enclosing
// scope and binds to a `state` variable at the call site -- confirm call
// sites before invoking it twice in one block.
#define KERNEL_RUN(NAME, DIMS, N, ...) \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
switch (DIMS) { \
case 1: NAME<real, 1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 2: NAME<real, 2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 3: NAME<real, 3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
} \
THCudaCheck(cudaGetLastError());
torch_scatter/kernel/generic/kernel.cu
View file @
6372815e
...
@@ -24,9 +24,13 @@ void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor
...
@@ -24,9 +24,13 @@ void scatter_(max)(THCState *state, int dim, THCTensor *output, THCudaLongTensor
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
int64_t
>
indexInfo
=
thc_getTensorInfo_Long
(
state
,
index
);
TensorInfo
<
int64_t
>
indexInfo
=
thc_getTensorInfo_Long
(
state
,
index
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
int64_t
>
arg
Output
Info
=
thc_getTensorInfo_Long
(
state
,
arg_output
);
TensorInfo
<
int64_t
>
argInfo
=
thc_getTensorInfo_Long
(
state
,
arg_output
);
maxKernel
<
real
,
-
1
><<<
GET_BLOCKS
(
n
),
NUM_THREADS
,
0
,
THCState_getCurrentStream
(
state
)
>>>
(
outputInfo
,
indexInfo
,
inputInfo
,
argOutputInfo
,
dim
,
n
);
KERNEL_RUN
(
maxKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
argInfo
,
dim
)
/* KERNEL_RUN(argKernel, indexInfo.dims, n, outputInfo, indexInfo, dim) */
/* maxKernel<real, -1><<<GET_BLOCKS(n), NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(outputInfo, indexInfo, inputInfo, dim, n); */
/* argKernel<real, -1><<<GET_BLOCKS(n), NUM_THREADS, 0, THCState_getCurrentStream(state)>>>(dim, n); */
}
}
void
scatter_
(
min
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
,
THCudaLongTensor
*
arg_output
)
{
void
scatter_
(
min
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
,
THCudaLongTensor
*
arg_output
)
{
...
...
torch_scatter/kernel/kernel.cu
View file @
6372815e
#include <THC/THC.h>
#include <THC/THC.h>
#include "THCAtomics.cuh"
#include "kernel.h"
#include "kernel.h"
#include "common.cuh"
#include "common.cuh"
...
@@ -13,9 +14,27 @@
...
@@ -13,9 +14,27 @@
#include "THCGenerateAllTypes.h"
#include "THCGenerateAllTypes.h"
// Scatter-max kernel: for each linear element i of `index`/`input`,
// atomically raises the mapped output element to
// max(output[...], input[i]). `arg` receives the argmax indices
// (not yet written -- see TODO below). `dim` is the scatter dimension,
// `n` the total number of elements to process.
// Launched over a 1-D grid; KERNEL_LOOP makes any grid size correct.
template <typename Real, int Dims>
__global__ void maxKernel(TensorInfo<Real> output, TensorInfo<int64_t> index,
                          TensorInfo<Real> input, TensorInfo<int64_t> arg,
                          const int dim, const int n) {
  KERNEL_LOOP(i, n) {
    int outputOffset = 0;
    int indexOffset = 0;
    int inputOffset = 0;
    int argOffset = 0;
    int curDimIndex;
    // BUG FIX: decompose a COPY of the linear id instead of mutating `i`;
    // dividing the loop variable itself corrupts KERNEL_LOOP's grid-stride
    // increment, so a second iteration would process wrong elements.
    int linearId = i;
    // Convert the linear id into per-dimension offsets (row-major, last
    // dimension fastest). The scatter dimension is skipped for output/arg:
    // their coordinate along `dim` comes from the index tensor below.
    for (int d = index.dims - 1; d >= 0; d--) {
      curDimIndex = linearId % index.size[d];
      indexOffset += curDimIndex * index.stride[d];
      inputOffset += curDimIndex * input.stride[d];
      if (d != dim) {
        outputOffset += curDimIndex * output.stride[d];
        argOffset += curDimIndex * arg.stride[d];
      }
      linearId /= index.size[d];
    }
    int64_t indexValue = index.data[indexOffset];
    // Out-of-range scatter indices trap in debug builds.
    assert(indexValue >= 0 && indexValue < output.size[dim]);
    outputOffset += indexValue * output.stride[dim];
    argOffset += indexValue * arg.stride[dim];
    atomicMax(&output.data[outputOffset], input.data[inputOffset]);
    // TODO: Do something with arg.
  }
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment