Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
eace3488
Commit
eace3488
authored
Feb 08, 2018
by
rusty1s
Browse files
performance boost, however, not finished yet
parent
37778e99
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
85 additions
and
70 deletions
+85
-70
setup.py
setup.py
+1
-1
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+11
-14
torch_cluster/functions/utils.py
torch_cluster/functions/utils.py
+4
-4
torch_cluster/kernel/generic/kernel.cu
torch_cluster/kernel/generic/kernel.cu
+18
-10
torch_cluster/kernel/kernel.cu
torch_cluster/kernel/kernel.cu
+9
-6
torch_cluster/kernel/kernel.h
torch_cluster/kernel/kernel.h
+7
-7
torch_cluster/src/cpu.h
torch_cluster/src/cpu.h
+7
-7
torch_cluster/src/cuda.h
torch_cluster/src/cuda.h
+7
-7
torch_cluster/src/generic/cpu.c
torch_cluster/src/generic/cpu.c
+18
-11
torch_cluster/src/generic/cuda.c
torch_cluster/src/generic/cuda.c
+2
-2
No files found.
setup.py
View file @
eace3488
...
@@ -2,7 +2,7 @@ from os import path as osp
...
@@ -2,7 +2,7 @@ from os import path as osp
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
__version__
=
'0.1.
1
'
__version__
=
'0.1.
2
'
url
=
'https://github.com/rusty1s/pytorch_cluster'
url
=
'https://github.com/rusty1s/pytorch_cluster'
install_requires
=
[
'cffi'
,
'torch-unique'
]
install_requires
=
[
'cffi'
,
'torch-unique'
]
...
...
torch_cluster/__init__.py
View file @
eace3488
from .functions.grid import grid_cluster

# Keep in sync with the version declared in setup.py.
__version__ = '0.1.2'

__all__ = ['grid_cluster', '__version__']
torch_cluster/functions/grid.py
View file @
eace3488
...
@@ -22,31 +22,28 @@ def grid_cluster(position, size, batch=None):
...
@@ -22,31 +22,28 @@ def grid_cluster(position, size, batch=None):
size
=
torch
.
cat
([
size
.
new
(
1
).
fill_
(
1
),
size
],
dim
=-
1
)
size
=
torch
.
cat
([
size
.
new
(
1
).
fill_
(
1
),
size
],
dim
=-
1
)
# Translate to minimal positive positions.
# Translate to minimal positive positions.
min
=
position
.
min
(
dim
=-
2
,
keepdim
=
True
)[
0
]
p_
min
=
position
.
min
(
dim
=-
2
,
keepdim
=
True
)[
0
]
position
=
position
-
min
position
=
position
-
p_
min
# Compute cluster count for each dimension.
# Compute maximal position for each dimension.
max
=
position
.
max
(
dim
=
0
)[
0
]
p_max
=
position
.
max
(
dim
=
0
)[
0
]
while
max
.
dim
()
>
1
:
while
p_max
.
dim
()
>
1
:
max
=
max
.
max
(
dim
=
0
)[
0
]
p_max
=
p_max
.
max
(
dim
=
0
)[
0
]
c_max
=
torch
.
floor
(
max
.
double
()
/
size
.
double
()
+
1
).
long
()
c_max
=
torch
.
clamp
(
c_max
,
min
=
1
)
C
=
c_max
.
prod
()
# Generate cluster tensor.
# Generate cluster tensor.
s
=
list
(
position
.
size
())
s
=
list
(
position
.
size
())[:
-
1
]
+
[
1
]
s
[
-
1
]
=
1
cluster
=
size
.
new
(
torch
.
Size
(
s
)).
long
()
cluster
=
c_max
.
new
(
torch
.
Size
(
s
))
# Fill cluster tensor and reshape.
# Fill cluster tensor and reshape.
size
=
size
.
type_as
(
position
)
size
=
size
.
type_as
(
position
)
func
=
get_func
(
'grid'
,
position
)
func
=
get_func
(
'grid'
,
position
)
func
(
C
,
cluster
,
position
,
size
,
c
_max
)
C
=
func
(
cluster
,
position
,
size
,
p
_max
)
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
cluster
,
u
=
consecutive
(
cluster
)
cluster
,
u
=
consecutive
(
cluster
)
if
batch
is
None
:
if
batch
is
None
:
return
cluster
return
cluster
else
:
else
:
batch
=
(
u
/
c_max
[
1
:].
prod
()).
long
()
print
(
p_max
.
tolist
(),
size
.
tolist
(),
C
)
batch
=
(
u
/
C
).
long
()
return
cluster
,
batch
return
cluster
,
batch
torch_cluster/functions/utils.py
View file @
eace3488
...
@@ -11,12 +11,12 @@ def get_func(name, tensor):
...
@@ -11,12 +11,12 @@ def get_func(name, tensor):
return
func
return
func
def get_type(max_value, cuda):
    """Return the smallest integer tensor class able to hold ``max_value``.

    Args:
        max_value: Largest non-negative value the tensor must represent.
        cuda: If ``True``, return the ``torch.cuda.*`` tensor class,
            otherwise the CPU one.

    Returns:
        A tensor class (Byte, Short, Int or Long), chosen by the signed
        upper bound of each dtype.
    """
    if max_value <= 255:
        return torch.cuda.ByteTensor if cuda else torch.ByteTensor
    elif max_value <= 32767:  # pragma: no cover
        return torch.cuda.ShortTensor if cuda else torch.ShortTensor
    elif max_value <= 2147483647:  # pragma: no cover
        return torch.cuda.IntTensor if cuda else torch.IntTensor
    else:  # pragma: no cover
        return torch.cuda.LongTensor if cuda else torch.LongTensor
...
...
torch_cluster/kernel/generic/kernel.cu
View file @
eace3488
...
@@ -2,29 +2,37 @@
...
@@ -2,29 +2,37 @@
#define THC_GENERIC_FILE "generic/kernel.cu"
#define THC_GENERIC_FILE "generic/kernel.cu"
#else
#else
/* Launch the grid-clustering kernel for every point in `position` and
 * return the total number of grid cells C.
 *
 * output      - int64 tensor receiving one flat cell id per point.
 * position    - point coordinates (last dimension = spatial dims).
 * size        - per-dimension grid cell size.
 * maxPosition - per-dimension maximal (already min-shifted) coordinate.
 */
int64_t cluster_(grid)(THCState *state, THCudaLongTensor *output, THCTensor *position,
                       THCTensor *size, THCTensor *maxPosition) {
  THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, position, size, maxPosition));
  THCAssertSameGPU(THCudaLongTensor_checkGPU(state, 1, output));
  THArgCheck(THCTensor_(nDimension)(state, position) <= MAX_DIMS, 1,
             "Tensor too large or too many dimensions");

  int64_t *outputData = THCudaLongTensor_data(state, output);
  TensorInfo<real> positionInfo = thc_(getTensorInfo)(state, position);
  real *sizeData = THCTensor_(data)(state, size);
  real *maxPositionData = THCTensor_(data)(state, maxPosition);

  const int N = THCudaLongTensor_nElement(state, output);
  int grid = GET_BLOCKS(N);
  cudaStream_t stream = THCState_getCurrentStream(state);

  /* Dispatch on the statically-known dimensionality; -1 is the generic path. */
  switch (positionInfo.dims) {
    case 1:
      gridKernel<real, 1><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, maxPositionData, N);
      break;
    case 2:
      gridKernel<real, 2><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, maxPositionData, N);
      break;
    case 3:
      gridKernel<real, 3><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, maxPositionData, N);
      break;
    default:
      gridKernel<real, -1><<<grid, NUM_THREADS, 0, stream>>>(outputData, positionInfo, sizeData, maxPositionData, N);
      break;
  }
  THCudaCheck(cudaGetLastError());

  /* Total cell count: product over dimensions of per-dimension cell counts.
   * Uses int64_t to match the declared return type (was a `real`, and the
   * loop body overwrote instead of accumulating).
   * NOTE(review): maxPositionData/sizeData point to *device* memory; reading
   * them from host code here is invalid without a device->host copy — this
   * function is marked unfinished upstream, confirm before relying on C. */
  int64_t C = 1;
  for (ptrdiff_t d = 1; d < THCTensor_(nElement)(state, size); d++) {
    C *= (int64_t) (maxPositionData[d] / sizeData[d]) + 1;
  }
  return C;
}
#endif
#endif
torch_cluster/kernel/kernel.cu
View file @
eace3488
...
@@ -12,15 +12,18 @@
...
@@ -12,15 +12,18 @@
#include "THCGenerateAllTypes.h"
#include "THCGenerateAllTypes.h"
/* Compute, for each of the N points, its flat grid-cell id.
 * The id is a row-major encoding over the D spatial dimensions:
 * walking d from D-1 down to 0, each dimension contributes
 * floor(coord/size) scaled by the product of cell counts of the
 * dimensions already processed. */
template <typename Real, int Dims>
__global__ void gridKernel(int64_t *output, TensorInfo<Real> position, Real *size,
                           Real *maxPosition, const int N) {
  KERNEL_LOOP(i, N) {
    int positionOffset = 0;
    IndexToOffset<Real, Dims>::compute(i, position, &positionOffset);

    int D = position.size[position.dims - 1];
    int64_t weight = 1;  /* int64_t: product of cell counts can exceed 32 bits */
    int64_t cluster = 0;
    for (int d = D - 1; d >= 0; d--) {
      cluster += weight * (int64_t) (position.data[positionOffset + d] / size[d]);
      weight *= (int64_t) (maxPosition[d] / size[d]) + 1;
    }
    output[i] = cluster;
  }
}
...
...
torch_cluster/kernel/kernel.h
View file @
eace3488
...
@@ -2,13 +2,13 @@
...
@@ -2,13 +2,13 @@
extern
"C"
{
extern
"C"
{
#endif
#endif
/* Per-dtype CUDA grid-clustering entry points (instantiated from
 * generic/kernel.cu).  Each writes one flat grid-cell id per point into
 * `output` and returns an int64_t cell count (computation marked
 * unfinished upstream — see generic/kernel.cu). */
int64_t cluster_grid_kernel_Float(THCState *state, THCudaLongTensor *output, THCudaTensor *position, THCudaTensor *size, THCudaTensor *maxPosition);
int64_t cluster_grid_kernel_Double(THCState *state, THCudaLongTensor *output, THCudaDoubleTensor *position, THCudaDoubleTensor *size, THCudaDoubleTensor *maxPosition);
int64_t cluster_grid_kernel_Byte(THCState *state, THCudaLongTensor *output, THCudaByteTensor *position, THCudaByteTensor *size, THCudaByteTensor *maxPosition);
int64_t cluster_grid_kernel_Char(THCState *state, THCudaLongTensor *output, THCudaCharTensor *position, THCudaCharTensor *size, THCudaCharTensor *maxPosition);
int64_t cluster_grid_kernel_Short(THCState *state, THCudaLongTensor *output, THCudaShortTensor *position, THCudaShortTensor *size, THCudaShortTensor *maxPosition);
int64_t cluster_grid_kernel_Int(THCState *state, THCudaLongTensor *output, THCudaIntTensor *position, THCudaIntTensor *size, THCudaIntTensor *maxPosition);
int64_t cluster_grid_kernel_Long(THCState *state, THCudaLongTensor *output, THCudaLongTensor *position, THCudaLongTensor *size, THCudaLongTensor *maxPosition);
#ifdef __cplusplus
#ifdef __cplusplus
}
}
...
...
torch_cluster/src/cpu.h
View file @
eace3488
/* Per-dtype CPU grid-clustering entry points (instantiated from
 * generic/cpu.c).  Each writes one flat grid-cell id per point into
 * `output` and returns the product of per-dimension cell counts. */
int64_t cluster_grid_Float(THLongTensor *output, THFloatTensor *position, THFloatTensor *size, THFloatTensor *maxPosition);
int64_t cluster_grid_Double(THLongTensor *output, THDoubleTensor *position, THDoubleTensor *size, THDoubleTensor *maxPosition);
int64_t cluster_grid_Byte(THLongTensor *output, THByteTensor *position, THByteTensor *size, THByteTensor *maxPosition);
int64_t cluster_grid_Char(THLongTensor *output, THCharTensor *position, THCharTensor *size, THCharTensor *maxPosition);
int64_t cluster_grid_Short(THLongTensor *output, THShortTensor *position, THShortTensor *size, THShortTensor *maxPosition);
int64_t cluster_grid_Int(THLongTensor *output, THIntTensor *position, THIntTensor *size, THIntTensor *maxPosition);
int64_t cluster_grid_Long(THLongTensor *output, THLongTensor *position, THLongTensor *size, THLongTensor *maxPosition);
torch_cluster/src/cuda.h
View file @
eace3488
/* Per-dtype CUDA-dispatch grid-clustering entry points (instantiated from
 * generic/cuda.c); thin wrappers around cluster_grid_kernel_* that supply
 * the file-scope THCState. */
int64_t cluster_grid_cuda_Float(THCudaLongTensor *output, THCudaTensor *position, THCudaTensor *size, THCudaTensor *maxPosition);
int64_t cluster_grid_cuda_Double(THCudaLongTensor *output, THCudaDoubleTensor *position, THCudaDoubleTensor *size, THCudaDoubleTensor *maxPosition);
int64_t cluster_grid_cuda_Byte(THCudaLongTensor *output, THCudaByteTensor *position, THCudaByteTensor *size, THCudaByteTensor *maxPosition);
int64_t cluster_grid_cuda_Char(THCudaLongTensor *output, THCudaCharTensor *position, THCudaCharTensor *size, THCudaCharTensor *maxPosition);
int64_t cluster_grid_cuda_Short(THCudaLongTensor *output, THCudaShortTensor *position, THCudaShortTensor *size, THCudaShortTensor *maxPosition);
int64_t cluster_grid_cuda_Int(THCudaLongTensor *output, THCudaIntTensor *position, THCudaIntTensor *size, THCudaIntTensor *maxPosition);
int64_t cluster_grid_cuda_Long(THCudaLongTensor *output, THCudaLongTensor *position, THCudaLongTensor *size, THCudaLongTensor *maxPosition);
torch_cluster/src/generic/cpu.c
View file @
eace3488
...
@@ -2,20 +2,27 @@
...
@@ -2,20 +2,27 @@
#define TH_GENERIC_FILE "generic/cpu.c"
#define TH_GENERIC_FILE "generic/cpu.c"
#else
#else
void
cluster_
(
grid
)(
int
C
,
THLongTensor
*
output
,
THTensor
*
position
,
THTensor
*
size
,
TH
Long
Tensor
*
count
)
{
int64_t
cluster_
(
grid
)(
THLongTensor
*
output
,
THTensor
*
position
,
THTensor
*
size
,
THTensor
*
maxPosition
)
{
real
*
size_data
=
size
->
storage
->
data
+
size
->
storageOffset
;
real
*
size_data
=
size
->
storage
->
data
+
size
->
storageOffset
;
int64_t
*
count_data
=
count
->
storage
->
data
+
count
->
storageOffset
;
real
*
maxPosition_data
=
maxPosition
->
storage
->
data
+
maxPosition
->
storageOffset
;
int64_t
D
,
d
,
i
,
c
,
tmp
;
D
=
THTensor_
(
nDimension
)(
position
);
int64_t
Dims
=
THTensor_
(
nDimension
)(
position
);
d
=
THTensor_
(
size
)(
position
,
D
-
1
);
int64_t
D
=
THTensor_
(
size
)(
position
,
Dims
-
1
);
TH_TENSOR_DIM_APPLY2
(
int64_t
,
output
,
real
,
position
,
D
-
1
,
tmp
=
C
;
c
=
0
;
TH_TENSOR_DIM_APPLY2
(
int64_t
,
output
,
real
,
position
,
Dims
-
1
,
for
(
i
=
0
;
i
<
d
;
i
++
)
{
int
weight
=
1
;
int64_t
cluster
=
0
;
tmp
=
tmp
/
*
(
count_data
+
i
);
for
(
int
d
=
D
-
1
;
d
>=
0
;
d
--
)
{
c
+=
tmp
*
(
int64_t
)
(
*
(
position_data
+
i
*
position_stride
)
/
*
(
size_data
+
i
));
cluster
+=
weight
*
(
int64_t
)
(
*
(
position_data
+
d
*
position_stride
)
/
*
(
size_data
+
d
));
weight
*=
(
int64_t
)
(
maxPosition_data
[
d
]
/
size_data
[
d
])
+
1
;
}
}
output_data
[
0
]
=
c
;
output_data
[
0
]
=
c
luster
;
)
)
int64_t
C
=
1
;
for
(
int
d
=
1
;
d
<
D
;
d
++
)
{
C
*=
(
int64_t
)
(
maxPosition_data
[
d
]
/
size_data
[
d
])
+
1
;
}
return
C
;
}
}
#endif
#endif
torch_cluster/src/generic/cuda.c
View file @
eace3488
...
@@ -2,8 +2,8 @@
...
@@ -2,8 +2,8 @@
#define THC_GENERIC_FILE "generic/cuda.c"
#define THC_GENERIC_FILE "generic/cuda.c"
#else
#else
void
cluster_
(
grid
)(
int
C
,
THCudaLongTensor
*
output
,
THCTensor
*
position
,
THCTensor
*
size
,
THC
udaLong
Tensor
*
count
)
{
int64_t
cluster_
(
grid
)(
THCudaLongTensor
*
output
,
THCTensor
*
position
,
THCTensor
*
size
,
THCTensor
*
maxPosition
)
{
cluster_kernel_
(
grid
)(
state
,
C
,
output
,
position
,
size
,
count
);
return
cluster_kernel_
(
grid
)(
state
,
output
,
position
,
size
,
maxPosition
);
}
}
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment