Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
da65da5b
Commit
da65da5b
authored
Mar 28, 2018
by
rusty1s
Browse files
changed arg order
parent
2a8339db
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
52 additions
and
31 deletions
+52
-31
test/utils/test_ffi.py
test/utils/test_ffi.py
+0
-0
test/utils/test_permute.py
test/utils/test_permute.py
+4
-0
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+2
-2
torch_cluster/functions/serial.py
torch_cluster/functions/serial.py
+2
-2
torch_cluster/functions/utils/ffi.py
torch_cluster/functions/utils/ffi.py
+19
-2
torch_cluster/kernel/generic/serial.cu
torch_cluster/kernel/generic/serial.cu
+1
-1
torch_cluster/kernel/serial.h
torch_cluster/kernel/serial.h
+7
-7
torch_cluster/src/generic/serial_cpu.c
torch_cluster/src/generic/serial_cpu.c
+1
-1
torch_cluster/src/generic/serial_cuda.c
torch_cluster/src/generic/serial_cuda.c
+2
-2
torch_cluster/src/serial_cpu.h
torch_cluster/src/serial_cpu.h
+7
-7
torch_cluster/src/serial_cuda.h
torch_cluster/src/serial_cuda.h
+7
-7
No files found.
test/utils/test_ffi.py
0 → 100644
View file @
da65da5b
test/utils/test_permute.py
View file @
da65da5b
...
...
@@ -27,6 +27,9 @@ def test_permute_cpu():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
def
test_sort_gpu
():
# pragma: no cover
# Note that `sort` is not stable on the GPU, so it does not preserve the
# relative ordering of equivalent row elements. Thus, the expected column
# vector differs from the CPU version (which is stable).
row
=
torch
.
cuda
.
LongTensor
([
0
,
1
,
0
,
2
,
1
,
2
,
1
,
3
,
2
,
3
])
col
=
torch
.
cuda
.
LongTensor
([
1
,
0
,
2
,
0
,
2
,
1
,
3
,
1
,
3
,
2
])
row
,
col
=
sort
(
row
,
col
)
...
...
@@ -38,6 +41,7 @@ def test_sort_gpu(): # pragma: no cover
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
def
test_permute_gpu
():
# pragma: no cover
# Equivalent to `sort`, `permute` is not stable on the GPU (see above).
row
=
torch
.
cuda
.
LongTensor
([
0
,
1
,
0
,
2
,
1
,
2
,
1
,
3
,
2
,
3
])
col
=
torch
.
cuda
.
LongTensor
([
1
,
0
,
2
,
0
,
2
,
1
,
3
,
1
,
3
,
2
])
node_rid
=
torch
.
cuda
.
LongTensor
([
2
,
1
,
3
,
0
])
...
...
torch_cluster/functions/grid.py
View file @
da65da5b
...
...
@@ -2,7 +2,7 @@ from __future__ import division
import
torch
from
.utils.ffi
import
get_typed_func
from
.utils.ffi
import
_
get_typed_func
from
.utils.consecutive
import
consecutive
...
...
@@ -70,7 +70,7 @@ def _grid_cluster(position, size, cluster_size):
cluster
=
cluster_size
.
new
(
torch
.
Size
(
list
(
position
.
size
())[:
-
1
]))
cluster
=
cluster
.
unsqueeze
(
dim
=-
1
)
func
=
get_typed_func
(
'grid'
,
position
)
func
=
_
get_typed_func
(
'grid'
,
position
)
func
(
C
,
cluster
,
position
,
size
,
cluster_size
)
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
...
...
torch_cluster/functions/serial.py
View file @
da65da5b
from
.utils.permute
import
permute
from
.utils.degree
import
node_degree
from
.utils.ffi
import
get_func
from
.utils.ffi
import
_
get_func
from
.utils.consecutive
import
consecutive
...
...
@@ -11,7 +11,7 @@ def serial_cluster(edge_index, batch=None, num_nodes=None):
degree
=
node_degree
(
row
,
num_nodes
,
out
=
row
.
new
())
cluster
=
edge_index
.
new
(
num_nodes
).
fill_
(
-
1
)
func
=
get_func
(
'random'
,
cluster
)
func
=
_
get_func
(
'random'
,
cluster
)
func
(
cluster
,
row
,
col
,
degree
)
cluster
,
u
=
consecutive
(
cluster
)
...
...
torch_cluster/functions/utils/ffi.py
View file @
da65da5b
from
..._ext
import
ffi
def
get_func
(
name
,
tensor
):
def
_
get_func
(
name
,
tensor
):
cuda
=
'_cuda'
if
tensor
.
is_cuda
else
''
return
getattr
(
ffi
,
'cluster_{}{}'
.
format
(
name
,
cuda
))
def
get_typed_func
(
name
,
tensor
):
def
_
get_typed_func
(
name
,
tensor
):
typename
=
type
(
tensor
).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
tensor
.
is_cuda
else
''
return
getattr
(
ffi
,
'cluster_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
def
ffi_serial
(
output
,
row
,
col
,
degree
,
weight
=
None
):
if
weight
is
None
:
func
=
_get_func
(
'serial'
,
row
)
func
(
output
,
row
,
col
,
degree
)
return
output
else
:
func
=
_get_typed_func
(
'serial'
,
weight
)
func
(
output
,
row
,
col
,
degree
,
weight
)
return
output
def
ffi_grid
(
C
,
output
,
position
,
size
,
count
):
func
=
_get_typed_func
(
'grid'
,
position
)
func
(
C
,
output
,
position
,
size
,
count
)
return
output
torch_cluster/kernel/generic/serial.cu
View file @
da65da5b
...
...
@@ -2,7 +2,7 @@
#define THC_GENERIC_FILE "generic/serial.cu"
#else
void
cluster_
(
serial
)(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCTensor
*
weight
,
THCudaLongTensor
*
degree
)
{
void
cluster_
(
serial
)(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCTensor
*
weight
)
{
}
#endif
...
...
torch_cluster/kernel/serial.h
View file @
da65da5b
...
...
@@ -4,13 +4,13 @@ extern "C" {
void
cluster_serial_kernel
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Float
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Double
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCuda
Double
Tensor
*
weight
,
THCuda
Long
Tensor
*
degree
);
void
cluster_serial_kernel_Byte
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaByteTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Char
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaCharTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Short
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaShortTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Int
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaIntTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Long
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_kernel_Float
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaTensor
*
weight
);
void
cluster_serial_kernel_Double
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCuda
Long
Tensor
*
degree
,
THCuda
Double
Tensor
*
weight
);
void
cluster_serial_kernel_Byte
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaByteTensor
*
weight
);
void
cluster_serial_kernel_Char
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaCharTensor
*
weight
);
void
cluster_serial_kernel_Short
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaShortTensor
*
weight
);
void
cluster_serial_kernel_Int
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaIntTensor
*
weight
);
void
cluster_serial_kernel_Long
(
THCState
*
state
,
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaLongTensor
*
weight
);
#ifdef __cplusplus
}
...
...
torch_cluster/src/generic/serial_cpu.c
View file @
da65da5b
...
...
@@ -2,7 +2,7 @@
#define TH_GENERIC_FILE "generic/serial_cpu.c"
#else
void
cluster_
(
serial
)(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
TH
Tensor
*
weight
,
TH
LongTensor
*
degree
)
{
void
cluster_
(
serial
)(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THTensor
*
weight
)
{
real
*
weight_data
=
weight
->
storage
->
data
+
weight
->
storageOffset
;
real
max_weight
,
w
;
int64_t
d
,
c
;
...
...
torch_cluster/src/generic/serial_cuda.c
View file @
da65da5b
...
...
@@ -2,8 +2,8 @@
#define THC_GENERIC_FILE "generic/serial_cuda.c"
#else
void
cluster_
(
serial
)(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCTensor
*
weight
,
THCudaLongTensor
*
degree
)
{
cluster_kernel_
(
serial
)(
state
,
output
,
row
,
col
,
weight
,
degree
);
void
cluster_
(
serial
)(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCTensor
*
weight
)
{
cluster_kernel_
(
serial
)(
state
,
output
,
row
,
col
,
degree
,
weight
);
}
#endif
...
...
torch_cluster/src/serial_cpu.h
View file @
da65da5b
void
cluster_serial
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
);
void
cluster_serial_Float
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THFloatTensor
*
weight
,
THLongTensor
*
degree
);
void
cluster_serial_Double
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
TH
Double
Tensor
*
weight
,
THLong
Tensor
*
degree
);
void
cluster_serial_Byte
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THByteTensor
*
weight
,
THLongTensor
*
degree
);
void
cluster_serial_Char
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THCharTensor
*
weight
,
THLongTensor
*
degree
);
void
cluster_serial_Short
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THShortTensor
*
weight
,
THLongTensor
*
degree
);
void
cluster_serial_Int
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THIntTensor
*
weight
,
THLongTensor
*
degree
);
void
cluster_serial_Long
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
weight
,
THLongTensor
*
degree
);
void
cluster_serial_Float
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THFloatTensor
*
weight
);
void
cluster_serial_Double
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
TH
Long
Tensor
*
degree
,
THDouble
Tensor
*
weight
);
void
cluster_serial_Byte
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THByteTensor
*
weight
);
void
cluster_serial_Char
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THCharTensor
*
weight
);
void
cluster_serial_Short
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THShortTensor
*
weight
);
void
cluster_serial_Int
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THIntTensor
*
weight
);
void
cluster_serial_Long
(
THLongTensor
*
output
,
THLongTensor
*
row
,
THLongTensor
*
col
,
THLongTensor
*
degree
,
THLongTensor
*
weight
);
torch_cluster/src/serial_cuda.h
View file @
da65da5b
void
cluster_serial_cuda
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Float
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Double
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCuda
Double
Tensor
*
weight
,
THCuda
Long
Tensor
*
degree
);
void
cluster_serial_cuda_Byte
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaByteTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Char
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaCharTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Short
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaShortTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Int
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaIntTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Long
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
weight
,
THCudaLongTensor
*
degree
);
void
cluster_serial_cuda_Float
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaTensor
*
weight
);
void
cluster_serial_cuda_Double
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCuda
Long
Tensor
*
degree
,
THCuda
Double
Tensor
*
weight
);
void
cluster_serial_cuda_Byte
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaByteTensor
*
weight
);
void
cluster_serial_cuda_Char
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaCharTensor
*
weight
);
void
cluster_serial_cuda_Short
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaShortTensor
*
weight
);
void
cluster_serial_cuda_Int
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaIntTensor
*
weight
);
void
cluster_serial_cuda_Long
(
THCudaLongTensor
*
output
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
degree
,
THCudaLongTensor
*
weight
);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment