Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
aaa5a410
Commit
aaa5a410
authored
Dec 20, 2017
by
rusty1s
Browse files
integer atomic
parent
cc561ac4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
19 deletions
+35
-19
torch_scatter/kernel/THCAtomics.cuh
torch_scatter/kernel/THCAtomics.cuh
+33
-17
torch_scatter/kernel/common.cuh
torch_scatter/kernel/common.cuh
+1
-1
torch_scatter/kernel/kernel.cu
torch_scatter/kernel/kernel.cu
+1
-1
No files found.
torch_scatter/kernel/THCAtomics.cuh
View file @
aaa5a410
#define OP(X, Y) max(X, Y)
template
<
typename
T
,
size_t
n
>
struct
Atomic
Max
IntegerImpl
;
struct
AtomicIntegerImpl
;
template
<
typename
T
>
struct
Atomic
Max
IntegerImpl
<
T
,
1
>
{
struct
AtomicIntegerImpl
<
T
,
1
>
{
inline
__device__
void
operator
()(
T
*
address
,
T
val
)
{
uint32_t
*
address_as_ui
=
(
uint32_t
*
)
(
address
-
((
size_t
)
address
&
3
));
uint32_t
old
=
*
address_as_ui
;
uint32_t
shift
=
(((
size_t
)
address
&
3
)
*
8
);
uint32_t
sum
;
uint32_t
res
;
uint32_t
assumed
;
do
{
assumed
=
old
;
sum
=
max
(
val
,
T
((
old
>>
shift
)
&
0xff
));
old
=
(
old
&
~
(
0x000000ff
<<
shift
))
|
(
sum
<<
shift
);
res
=
OP
(
val
,
T
((
old
>>
shift
)
&
0xff
));
old
=
(
old
&
~
(
0x000000ff
<<
shift
))
|
(
res
<<
shift
);
old
=
atomicCAS
(
address_as_ui
,
assumed
,
old
);
}
while
(
assumed
!=
old
);
}
};
template
<
typename
T
>
struct
Atomic
Max
IntegerImpl
<
T
,
2
>
{
struct
AtomicIntegerImpl
<
T
,
2
>
{
inline
__device__
void
operator
()(
T
*
address
,
T
val
)
{
uint32_t
*
address_as_ui
=
(
uint32_t
*
)
((
char
*
)
address
-
((
size_t
)
address
&
2
));
uint32_t
old
=
*
address_as_ui
;
uint32_t
sum
;
uint32_t
res
;
uint32_t
newval
;
uint32_t
assumed
;
do
{
assumed
=
old
;
sum
=
max
(
val
,
(
size_t
)
address
&
2
?
T
(
old
>>
16
)
:
T
(
old
&
0xffff
));
newval
=
(
size_t
)
address
&
2
?
(
old
&
0xffff
)
|
(
sum
<<
16
)
:
(
old
&
0xffff0000
)
|
sum
;
res
=
OP
(
val
,
(
size_t
)
address
&
2
?
T
(
old
>>
16
)
:
T
(
old
&
0xffff
));
newval
=
(
size_t
)
address
&
2
?
(
old
&
0xffff
)
|
(
res
<<
16
)
:
(
old
&
0xffff0000
)
|
res
;
old
=
atomicCAS
(
address_as_ui
,
assumed
,
newval
);
}
while
(
assumed
!=
old
);
}
};
template
<
typename
T
>
struct
AtomicMaxIntegerImpl
<
T
,
8
>
{
struct
AtomicIntegerImpl
<
T
,
4
>
{
inline
__device__
void
operator
()(
T
*
address
,
T
val
)
{
uint32_t
*
address_as_ull
=
(
uint32_t
*
)
(
address
);
uint32_t
old
=
*
address_as_ull
;
uint32_t
assumed
;
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
OP
(
val
,
(
T
)
old
));
}
while
(
assumed
!=
old
);
}
};
template
<
typename
T
>
struct
AtomicIntegerImpl
<
T
,
8
>
{
inline
__device__
void
operator
()(
T
*
address
,
T
val
)
{
unsigned
long
long
*
address_as_ull
=
(
unsigned
long
long
*
)
(
address
);
unsigned
long
long
old
=
*
address_as_ull
;
...
...
@@ -46,25 +62,25 @@ struct AtomicMaxIntegerImpl<T, 8> {
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
max
(
val
,
(
T
)
old
));
old
=
atomicCAS
(
address_as_ull
,
assumed
,
OP
(
val
,
(
T
)
old
));
}
while
(
assumed
!=
old
);
}
};
static
inline
__device__
void
atomicMax
(
uint8_t
*
address
,
uint8_t
val
)
{
Atomic
Max
IntegerImpl
<
uint8_t
,
sizeof
(
uint8_t
)
>
()(
address
,
val
);
AtomicIntegerImpl
<
uint8_t
,
sizeof
(
uint8_t
)
>
()(
address
,
val
);
}
static
inline
__device__
void
atomicMax
(
int8_t
*
address
,
int8_t
val
)
{
Atomic
Max
IntegerImpl
<
int8_t
,
sizeof
(
int8_t
)
>
()(
address
,
val
);
AtomicIntegerImpl
<
int8_t
,
sizeof
(
int8_t
)
>
()(
address
,
val
);
}
static
inline
__device__
void
atomicMax
(
int16_t
*
address
,
int16_t
val
)
{
Atomic
Max
IntegerImpl
<
int16_t
,
sizeof
(
int16_t
)
>
()(
address
,
val
);
AtomicIntegerImpl
<
int16_t
,
sizeof
(
int16_t
)
>
()(
address
,
val
);
}
static
inline
__device__
void
atomicMax
(
int64_t
*
address
,
int64_t
val
)
{
Atomic
Max
IntegerImpl
<
int64_t
,
sizeof
(
int64_t
)
>
()(
address
,
val
);
AtomicIntegerImpl
<
int64_t
,
sizeof
(
int64_t
)
>
()(
address
,
val
);
}
#ifdef CUDA_HALF_TENSOR
...
...
@@ -78,7 +94,7 @@ static inline __device__ void atomicMax(float *address, float val) {
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_i
,
assumed
,
__float_as_int
(
max
(
val
,
__int_as_float
(
assumed
))));
old
=
atomicCAS
(
address_as_i
,
assumed
,
__float_as_int
(
OP
(
val
,
__int_as_float
(
assumed
))));
}
while
(
assumed
!=
old
);
}
...
...
@@ -89,6 +105,6 @@ static inline __device__ void atomicMax(double *address, double val) {
do
{
assumed
=
old
;
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
max
(
val
,
__longlong_as_double
(
assumed
))));
old
=
atomicCAS
(
address_as_ull
,
assumed
,
__double_as_longlong
(
OP
(
val
,
__longlong_as_double
(
assumed
))));
}
while
(
assumed
!=
old
);
}
torch_scatter/kernel/common.cuh
View file @
aaa5a410
...
...
@@ -5,7 +5,7 @@ inline int GET_BLOCKS(const int n) {
return
(
n
+
NUM_THREADS
-
1
)
/
NUM_THREADS
;
}
template
<
typename
T
>
template
<
typename
T
>
struct
TensorInfo
{
TensorInfo
(
T
*
t
,
int
d
,
int
sz
[
MAX_DIMS
],
int
st
[
MAX_DIMS
])
{
data
=
t
;
dims
=
d
;
...
...
torch_scatter/kernel/kernel.cu
View file @
aaa5a410
...
...
@@ -13,7 +13,7 @@
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"
template
<
typename
Real
,
int
Dims
>
template
<
typename
Real
,
int
Dims
>
__global__
void
maxKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
input
,
TensorInfo
<
int64_t
>
arg
,
const
int
dim
,
const
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
inputOffset
=
0
;
int
argOffset
=
0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment