Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
3f97baae
Commit
3f97baae
authored
Mar 30, 2018
by
rusty1s
Browse files
gridKernel impl
parent
5d7997a0
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
86 additions
and
58 deletions
+86
-58
aten/TH/generic/THGrid.c
aten/TH/generic/THGrid.c
+1
-1
aten/THC/THCGrid.cu
aten/THC/THCGrid.cu
+14
-0
aten/THC/common.h
aten/THC/common.h
+34
-0
aten/THC/generic/THCGrid.cu
aten/THC/generic/THCGrid.cu
+10
-1
benchmark/benchmark.py
benchmark/benchmark.py
+0
-31
benchmark/bernoulli.py
benchmark/bernoulli.py
+0
-19
benchmark/ffi.py
benchmark/ffi.py
+23
-0
build.py
build.py
+4
-4
torch_cluster/kernel/serial.h
torch_cluster/kernel/serial.h
+0
-2
No files found.
aten/TH/generic/THGrid.c
View file @
3f97baae
...
@@ -12,7 +12,7 @@ void THGrid_(THLongTensor *cluster, THTensor *pos, THTensor *size, THLongTensor
...
@@ -12,7 +12,7 @@ void THGrid_(THLongTensor *cluster, THTensor *pos, THTensor *size, THLongTensor
for
(
n
=
0
;
n
<
THTensor_
(
size
)(
pos
,
0
);
n
++
)
{
for
(
n
=
0
;
n
<
THTensor_
(
size
)(
pos
,
0
);
n
++
)
{
coef
=
1
;
value
=
0
;
coef
=
1
;
value
=
0
;
for
(
d
=
0
;
d
<
THTensor_
(
size
)(
pos
,
1
);
d
++
)
{
for
(
d
=
0
;
d
<
THTensor_
(
size
)(
pos
,
1
);
d
++
)
{
value
+=
coef
*
(
int64_t
)
(
*
(
posData
+
d
*
pos
->
stride
[
1
]
)
/
sizeData
[
d
]);
value
+=
coef
*
(
int64_t
)
(
posData
[
d
*
pos
->
stride
[
1
]
]
/
sizeData
[
d
]);
coef
*=
countData
[
d
];
coef
*=
countData
[
d
];
}
}
posData
+=
pos
->
stride
[
0
];
posData
+=
pos
->
stride
[
0
];
...
...
aten/THC/THCGrid.cu
View file @
3f97baae
#include "THCGrid.h"
#include "THCGrid.h"
template
<
typename
real
,
int
dims
>
__global__
void
gridKernel
(
int64_t
*
cluster
,
TensorInfo
<
real
>
posInfo
,
real
*
size
,
int64_t
*
count
,
const
int
nNodes
)
{
KERNEL_LOOP
(
i
,
nNodes
)
{
real
*
pos
=
posInfo
->
data
+
i
*
posInfo
->
stride
[
0
];
int64_t
coef
=
1
,
value
=
0
;
for
(
ptrdiff_t
d
=
0
;
d
<
dims
;
d
++
)
{
value
+=
coef
*
(
int64_t
)
(
pos
[
d
*
posInfo
->
stride
[
1
]]
/
size
[
d
]);
coef
*=
count
[
d
];
}
cluster
[
i
]
=
value
;
}
}
#include "generic/THCGrid.cu"
#include "generic/THCGrid.cu"
#include "THC/THCGenerateAllTypes.h"
#include "THC/THCGenerateAllTypes.h"
aten/THC/common.h
0 → 100644
View file @
3f97baae
#ifndef THC_COMMON_INC
#define THC_COMMON_INC
#define KERNEL_LOOP(I, N) \
for (ptrdiff_t I = blockIdx.x * blockDim.x + threadIdx.x; I < I; I += blockDim.x * gridDim.x)
#define THC_assertSameGPU(...) THAssertMsg(THCTensor_(checkGPU)(__VA_ARGS__), \
"Some of the input tensors are located on different GPUs. Please move them to a single one.")
const
int
CUDA_NUM_THREADS
=
1024
;
inline
int
GET_BLOCKS
(
const
int
N
)
{
return
(
N
+
CUDA_NUM_THREADS
-
1
)
/
CUDA_NUM_THREADS
;
}
#define KERNEL_RUN(NAME, N, ...) \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
NAME<real><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
THCudaCheck(cudaGetLastError())
#define FIXED_DIM_KERNEL_RUN(NAME, N, DIMS, ...) \
int grid = GET_BLOCKS(N); \
cudaStream_t stream = THCState_getCurrentStream(state); \
switch (DIMS) { \
case 1: NAME<real, 1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 2: NAME<real, 2><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 3: NAME<real, 3><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
case 4: NAME<real, 4><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); break; \
default: NAME<real, -1><<<grid, NUM_THREADS, 0, stream>>>(__VA_ARGS__, N); \
} \
THCudaCheck(cudaGetLastError())
#endif // THC_COMMON_INC
aten/THC/generic/THCGrid.cu
View file @
3f97baae
...
@@ -4,7 +4,16 @@
...
@@ -4,7 +4,16 @@
void
THCGrid_
(
THCState
*
state
,
THCudaLongTensor
*
cluster
,
THCTensor
*
pos
,
THCTensor
*
size
,
void
THCGrid_
(
THCState
*
state
,
THCudaLongTensor
*
cluster
,
THCTensor
*
pos
,
THCTensor
*
size
,
THCudaLongTensor
*
count
)
{
THCudaLongTensor
*
count
)
{
printf
(
"THCGrid drin"
);
THC_assertSameGPU
(
state
,
4
,
cluster
,
pos
,
size
,
count
);
int64_t
*
clusterData
=
THCudaLongTensor_data
(
state
,
cluster
);
TensorInfo
<
real
>
posInfo
=
THC_
(
getTensorInfo
)(
state
,
pos
);
real
*
sizeData
=
THCTensor_
(
data
)(
state
,
size
);
int64_t
*
countData
=
THCudaLongTensor_data
(
state
,
count
);
const
int
nNodes
=
THCudaLongTensor_nElement
(
state
,
cluster
);
const
int
dims
=
THCTensor_
(
nElement
)(
size
);
FIXED_DIM_KERNEL_RUN
(
gridKernel
,
nNodes
,
dims
,
clusterData
,
posInfo
,
sizeData
,
countData
);
}
}
#endif // THC_GENERIC_FILE
#endif // THC_GENERIC_FILE
benchmark/benchmark.py
deleted
100644 → 0
View file @
5d7997a0
import
time
import
torch
from
torch_cluster
import
sparse_grid_cluster
n
=
90000000
s
=
1
/
64
print
(
'GPU ==================='
)
t
=
time
.
perf_counter
()
pos
=
torch
.
cuda
.
FloatTensor
(
n
,
3
).
uniform_
(
0
,
1
)
size
=
torch
.
cuda
.
FloatTensor
([
s
,
s
,
s
])
torch
.
cuda
.
synchronize
()
print
(
'Init:'
,
time
.
perf_counter
()
-
t
)
t_all
=
time
.
perf_counter
()
sparse_grid_cluster
(
pos
,
size
)
torch
.
cuda
.
synchronize
()
t_all
=
time
.
perf_counter
()
-
t_all
print
(
'All:'
,
t_all
)
print
(
'CPU ==================='
)
pos
=
pos
.
cpu
()
size
=
size
.
cpu
()
t_all
=
time
.
perf_counter
()
sparse_grid_cluster
(
pos
,
size
)
t_all
=
time
.
perf_counter
()
-
t_all
print
(
'All:'
,
t_all
)
benchmark/bernoulli.py
deleted
100644 → 0
View file @
5d7997a0
import
time
import
torch
from
torch_cluster.functions.utils.ffi
import
_get_func
output
=
torch
.
cuda
.
FloatTensor
(
500000000
).
fill_
(
0.5
)
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
torch
.
bernoulli
(
output
)
torch
.
cuda
.
synchronize
()
print
(
time
.
perf_counter
()
-
t
)
output
=
output
.
long
().
fill_
(
-
1
)
func
=
_get_func
(
'serial'
,
output
)
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
func
(
output
,
output
,
output
,
output
)
torch
.
cuda
.
synchronize
()
print
(
time
.
perf_counter
()
-
t
)
benchmark/ffi.py
0 → 100644
View file @
3f97baae
import
torch
from
torch_cluster._ext
import
ffi
print
(
ffi
.
__dict__
)
print
(
ffi
.
THByteGrid
)
cluster
=
torch
.
LongTensor
(
5
)
pos
=
torch
.
Tensor
([[
1
,
1
],
[
1
,
1
],
[
3
,
3
],
[
4
,
4
],
[
3
,
3
]])
size
=
torch
.
Tensor
([
2
,
2
])
count
=
torch
.
LongTensor
([
3
,
3
])
ffi
.
THFloatGrid
(
cluster
,
pos
,
size
,
count
)
print
(
cluster
)
cluster
=
torch
.
LongTensor
(
3
)
row
=
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
2
,
2
])
col
=
torch
.
LongTensor
([
1
,
2
,
0
,
2
,
0
,
1
])
deg
=
torch
.
LongTensor
([
2
,
2
,
2
])
weight
=
torch
.
Tensor
([
1
,
2
,
1
,
1
,
2
,
1
])
ffi
.
THFloatGreedy
(
cluster
,
row
,
col
,
deg
,
weight
)
print
(
cluster
)
build.py
View file @
3f97baae
...
@@ -8,11 +8,11 @@ from torch.utils.ffi import create_extension
...
@@ -8,11 +8,11 @@ from torch.utils.ffi import create_extension
if
osp
.
exists
(
'build'
):
if
osp
.
exists
(
'build'
):
shutil
.
rmtree
(
'build'
)
shutil
.
rmtree
(
'build'
)
files
=
[
'
serial
'
,
'
g
rid'
]
files
=
[
'
Greedy
'
,
'
G
rid'
]
headers
=
[
'
torch_cluster/src/{}_cpu
.h'
.
format
(
f
)
for
f
in
files
]
headers
=
[
'
aten/TH/TH{}
.h'
.
format
(
f
)
for
f
in
files
]
sources
=
[
'
torch_cluster/src/{}_cpu
.c'
.
format
(
f
)
for
f
in
files
]
sources
=
[
'
aten/TH/TH{}
.c'
.
format
(
f
)
for
f
in
files
]
include_dirs
=
[
'torch_cluster/src'
,
'aten/TH'
]
include_dirs
=
[
'aten/TH'
]
define_macros
=
[]
define_macros
=
[]
extra_objects
=
[]
extra_objects
=
[]
with_cuda
=
False
with_cuda
=
False
...
...
torch_cluster/kernel/serial.h
View file @
3f97baae
...
@@ -2,8 +2,6 @@
...
@@ -2,8 +2,6 @@
extern
"C"
{
extern
"C"
{
#endif
#endif
int
assignColor
(
THCState
*
state
,
THCudaLongTensor
*
color
);
void
cluster_serial_kernel
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Float
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaTensor
*
weight
);
void
cluster_serial_kernel_Float
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaTensor
*
weight
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment