Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
8b5daa16
Commit
8b5daa16
authored
Mar 30, 2018
by
rusty1s
Browse files
gridkernel done
parent
ad63397e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
72 additions
and
25 deletions
+72
-25
aten/THC/THCGrid.cu
aten/THC/THCGrid.cu
+8
-5
aten/THC/THCNumerics.cuh
aten/THC/THCNumerics.cuh
+27
-0
aten/THC/common.h
aten/THC/common.h
+16
-14
aten/THC/generic/THCGrid.cu
aten/THC/generic/THCGrid.cu
+2
-3
aten/THC/generic/common.h
aten/THC/generic/common.h
+16
-0
benchmark/ffi.py
benchmark/ffi.py
+3
-3
No files found.
aten/THC/THCGrid.cu
View file @
8b5daa16
#include "THCGrid.h"
#include "THCGrid.h"
template
<
typename
real
,
int
dims
>
#include "common.h"
__global__
void
gridKernel
(
int64_t
*
cluster
,
TensorInfo
<
real
>
posInfo
,
real
*
size
,
#include "THCNumerics.cuh"
template
<
typename
T
>
__global__
void
gridKernel
(
int64_t
*
cluster
,
TensorInfo
<
T
>
posInfo
,
T
*
size
,
int64_t
*
count
,
const
int
nNodes
)
{
int64_t
*
count
,
const
int
nNodes
)
{
KERNEL_LOOP
(
i
,
nNodes
)
{
KERNEL_LOOP
(
i
,
nNodes
)
{
real
*
pos
=
posInfo
->
data
+
i
*
posInfo
->
stride
[
0
];
T
*
pos
=
posInfo
.
data
+
i
*
posInfo
.
stride
[
0
];
int64_t
coef
=
1
,
value
=
0
;
int64_t
coef
=
1
,
value
=
0
;
for
(
ptrdiff_t
d
=
0
;
d
<
dims
;
d
++
)
{
for
(
ptrdiff_t
d
=
0
;
d
<
posInfo
.
dims
*
posInfo
.
stride
[
1
];
d
+=
posInfo
.
stride
[
1
]
)
{
value
+=
coef
*
(
int64_t
)
(
pos
[
d
*
posInfo
->
stride
[
1
]]
/
size
[
d
]);
value
+=
coef
*
THCNumerics
<
T
>::
floor
(
THCNumerics
<
T
>::
div
(
pos
[
d
],
size
[
d
])
)
;
coef
*=
count
[
d
];
coef
*=
count
[
d
];
}
}
cluster
[
i
]
=
value
;
cluster
[
i
]
=
value
;
...
...
aten/THC/THCNumerics.cuh
0 → 100644
View file @
8b5daa16
#ifndef THC_NUMERICS_INC
#define THC_NUMERICS_INC
#include "THC/THCHalf.h"
// Arithmetic helpers for the kernels, parameterised on the element type T.
// This primary template relies entirely on T's built-in operators and
// conversions.
template <typename T>
struct THCNumerics {
  // Returns lhs / rhs as computed by T's own division operator, converted
  // back to T.
  static inline __host__ __device__ T div(T lhs, T rhs) {
    const T quotient = lhs / rhs;
    return quotient;
  }

  // Converts `value` to int.
  // For floating-point T the conversion discards the fractional part (rounds
  // toward zero), so the result equals the mathematical floor only for
  // non-negative inputs.
  static inline __host__ __device__ int floor(T value) {
    return static_cast<int>(value);
  }
};
#ifdef CUDA_HALF_TENSOR
// h2f(A) / f2h(A): half -> float and float -> half conversion helpers used by
// the THCNumerics<half> specialization below. Which functions they expand to
// depends on whether __CUDA_ARCH__ is defined for the current compilation pass.
#ifdef __CUDA_ARCH__
// __CUDA_ARCH__ defined: expand to __half2float / __float2half.
#define h2f(A) __half2float(A)
#define f2h(A) __float2half(A)
#else // !__CUDA_ARCH__
// __CUDA_ARCH__ not defined: expand to THC_half2float / THC_float2half.
#define h2f(A) THC_half2float(A)
#define f2h(A) THC_float2half(A)
#endif // __CUDA_ARCH__
// Specialization of THCNumerics for `half`. Both operations go through float
// using the h2f / f2h conversion macros rather than operating on half directly.
template <>
struct THCNumerics<half> {
  // Divides a by b in float and converts the quotient back to half.
  static inline __host__ __device__ half div(half a, half b) {
    return f2h(
        h2f(a) / h2f(b));
  }

  // Converts `a` to float and then casts to int; the cast drops the
  // fractional part (rounds toward zero), so this matches a mathematical
  // floor only for non-negative values.
  static inline __host__ __device__ int floor(half a) {
    return static_cast<int>(h2f(a));
  }
};
#endif // CUDA_HALF_TENSOR
#endif // THC_NUMERICS_INC
aten/THC/common.h
View file @
8b5daa16
#ifndef THC_COMMON_INC
#ifndef THC_COMMON_INC
#define THC_COMMON_INC
#define THC_COMMON_INC
#define THCTensor_(NAME) TH_CONCAT_4(TH,CReal,Tensor_,NAME)
#define KERNEL_LOOP(I, N) \
#define KERNEL_LOOP(I, N) \
for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I <
I
; I += blockDim.x * gridDim.x)
for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I <
N
; I += blockDim.x * gridDim.x)
#define THC_assertSameGPU(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \
#define THC_assertSameGPU(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \
"Some of the input tensors are located on different GPUs. Please move them to a single one.")
"Some of the input tensors are located on different GPUs. Please move them to a single one.")
const
int
CUDA_NUM_THREADS
=
1024
;
const
int
MAX_DIMS
=
25
;
const
int
NUM_THREADS
=
1024
;
inline
int
GET_BLOCKS
(
const
int
N
)
{
inline
int
GET_BLOCKS
(
const
int
N
)
{
return
(
N
+
CUDA_
NUM_THREADS
-
1
)
/
CUDA_
NUM_THREADS
;
return
(
N
+
NUM_THREADS
-
1
)
/
NUM_THREADS
;
}
}
#define KERNEL_RUN(NAME, N, ...) \
#define KERNEL_RUN(NAME, N, ...) \
...
@@ -19,16 +22,15 @@ inline int GET_BLOCKS(const int N) {
...
@@ -19,16 +22,15 @@ inline int GET_BLOCKS(const int N) {
NAME<real><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
NAME<real><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
THCudaCheck(cudaGetLastError())
THCudaCheck(cudaGetLastError())
#define FIXED_DIM_KERNEL_RUN(NAME, N, DIMS, ...) \
template
<
typename
T
>
int grid = GET_BLOCKS(N); \
struct
TensorInfo
{
cudaStream_t stream = THCState_getCurrentStream(state); \
T
*
data
;
switch (DIMS) { \
int
dims
;
case 1: NAME<real, 1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
int
size
[
MAX_DIMS
];
case 2: NAME<real, 2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
int
stride
[
MAX_DIMS
];
case 3: NAME<real, 3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
};
case 4: NAME<real, 4><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
#include "generic/common.h"
} \
#include "THC/THCGenerateAllTypes.h"
THCudaCheck(cudaGetLastError())
#endif // THC_COMMON_INC
#endif // THC_COMMON_INC
aten/THC/generic/THCGrid.cu
View file @
8b5daa16
...
@@ -7,13 +7,12 @@ void THCGrid_(THCState *state, THCudaLongTensor *cluster, THCTensor *pos, THCTen
...
@@ -7,13 +7,12 @@ void THCGrid_(THCState *state, THCudaLongTensor *cluster, THCTensor *pos, THCTen
THC_assertSameGPU
(
state
,
4
,
cluster
,
pos
,
size
,
count
);
THC_assertSameGPU
(
state
,
4
,
cluster
,
pos
,
size
,
count
);
int64_t
*
clusterData
=
THCudaLongTensor_data
(
state
,
cluster
);
int64_t
*
clusterData
=
THCudaLongTensor_data
(
state
,
cluster
);
TensorInfo
<
real
>
posInfo
=
THC_
(
getTensorInfo
)(
state
,
pos
);
TensorInfo
<
real
>
posInfo
=
THC
Tensor
_
(
getTensorInfo
)(
state
,
pos
);
real
*
sizeData
=
THCTensor_
(
data
)(
state
,
size
);
real
*
sizeData
=
THCTensor_
(
data
)(
state
,
size
);
int64_t
*
countData
=
THCudaLongTensor_data
(
state
,
count
);
int64_t
*
countData
=
THCudaLongTensor_data
(
state
,
count
);
const
int
nNodes
=
THCudaLongTensor_nElement
(
state
,
cluster
);
const
int
nNodes
=
THCudaLongTensor_nElement
(
state
,
cluster
);
const
int
dims
=
THCTensor_
(
nElement
)(
size
);
KERNEL_RUN
(
gridKernel
,
nNodes
,
clusterData
,
posInfo
,
sizeData
,
countData
);
FIXED_DIM_KERNEL_RUN
(
gridKernel
,
nNodes
,
dims
,
clusterData
,
posInfo
,
sizeData
,
countData
);
}
}
#endif // THC_GENERIC_FILE
#endif // THC_GENERIC_FILE
aten/THC/generic/common.h
0 → 100644
View file @
8b5daa16
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/common.h"
#else
// Builds a TensorInfo<real> describing `tensor`: its data pointer, its number
// of dimensions, and the size and stride of every dimension. The result is
// returned by value.
//
// NOTE(review): nothing here checks that the tensor's dimension count fits in
// the fixed-size size[] / stride[] arrays of TensorInfo — confirm that callers
// guarantee this.
TensorInfo<real> THCTensor_(getTensorInfo)(THCState *state, THCTensor *tensor) {
  // Value-initialize so that every field starts from a known state.
  TensorInfo<real> info = TensorInfo<real>();

  info.data = THCTensor_(data)(state, tensor);
  info.dims = THCTensor_(nDimension)(state, tensor);

  // Copy the per-dimension extents and strides, one axis at a time.
  ptrdiff_t axis = 0;
  while (axis < info.dims) {
    info.size[axis] = THCTensor_(size)(state, tensor, axis);
    info.stride[axis] = THCTensor_(stride)(state, tensor, axis);
    axis++;
  }

  return info;
}
#endif // THC_GENERIC_FILE
benchmark/ffi.py
View file @
8b5daa16
...
@@ -2,9 +2,9 @@ import torch
...
@@ -2,9 +2,9 @@ import torch
from
torch_cluster._ext
import
ffi
from
torch_cluster._ext
import
ffi
cluster
=
torch
.
cuda
.
LongTensor
(
5
)
cluster
=
torch
.
cuda
.
LongTensor
(
5
)
pos
=
torch
.
cuda
.
FloatTensor
(
5
,
2
)
pos
=
torch
.
cuda
.
FloatTensor
(
[[
1
,
1
],
[
3
,
3
],
[
1
,
1
],
[
5
,
5
],
[
3
,
3
]]
)
size
=
torch
.
cuda
.
FloatTensor
(
2
)
size
=
torch
.
cuda
.
FloatTensor
(
[
2
,
2
]
)
count
=
torch
.
cuda
.
LongTensor
(
2
)
count
=
torch
.
cuda
.
LongTensor
(
[
3
,
3
]
)
func
=
ffi
.
THCCFloatGrid
func
=
ffi
.
THCCFloatGrid
print
(
func
)
print
(
func
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment