Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
4a61d70f
Commit
4a61d70f
authored
Mar 01, 2020
by
rusty1s
Browse files
better csrc api
parent
0e7f4b8e
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
102 additions
and
70 deletions
+102
-70
README.md
README.md
+1
-1
csrc/cpu/fps_cpu.cpp
csrc/cpu/fps_cpu.cpp
+6
-12
csrc/cpu/fps_cpu.h
csrc/cpu/fps_cpu.h
+1
-2
csrc/cpu/graclus_cpu.cpp
csrc/cpu/graclus_cpu.cpp
+23
-38
csrc/cpu/graclus_cpu.h
csrc/cpu/graclus_cpu.h
+2
-3
csrc/cuda/fps_cuda.cu
csrc/cuda/fps_cuda.cu
+41
-0
csrc/cuda/fps_cuda.h
csrc/cuda/fps_cuda.h
+6
-0
csrc/fps.cpp
csrc/fps.cpp
+3
-4
csrc/graclus.cpp
csrc/graclus.cpp
+5
-6
setup.py
setup.py
+1
-0
torch_cluster/fps.py
torch_cluster/fps.py
+3
-2
torch_cluster/graclus.py
torch_cluster/graclus.py
+9
-1
torch_cluster/rw.py
torch_cluster/rw.py
+1
-1
No files found.
README.md
View file @
4a61d70f
...
...
@@ -69,7 +69,7 @@ Then run:
pip install torch-cluster
```
When running in a docker container without
nvidia
driver, PyTorch needs to evaluate the compute capabilities and may fail.
When running in a docker container without
NVIDIA
driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
```
...
...
csrc/cpu/fps_cpu.cpp
View file @
4a61d70f
...
...
@@ -6,23 +6,16 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
return
(
x
-
x
[
idx
]).
norm
(
2
,
1
);
}
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
optional
<
torch
::
Tensor
>
optional_ptr
,
double
ratio
,
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
bool
random_start
)
{
CHECK_CPU
(
src
);
if
(
optional_ptr
.
has_value
())
{
CHECK_CPU
(
optional_ptr
.
value
());
CHECK_INPUT
(
optional_ptr
.
value
().
dim
()
==
1
);
}
CHECK_CPU
(
ptr
);
CHECK_INPUT
(
ptr
.
dim
()
==
1
);
AT_ASSERTM
(
ratio
>
0
and
ratio
<
1
,
"Invalid input"
);
if
(
!
optional_ptr
.
has_value
())
optional_ptr
=
torch
::
tensor
({
0
,
src
.
size
(
0
)},
src
.
options
().
dtype
(
torch
::
kLong
));
src
=
src
.
view
({
src
.
size
(
0
),
-
1
}).
contiguous
();
auto
ptr
=
optional_ptr
.
value
()
.
contiguous
();
ptr
=
ptr
.
contiguous
();
auto
batch_size
=
ptr
.
size
(
0
)
-
1
;
auto
deg
=
ptr
.
narrow
(
0
,
1
,
batch_size
)
-
ptr
.
narrow
(
0
,
0
,
batch_size
);
...
...
@@ -42,7 +35,7 @@ torch::Tensor fps_cpu(torch::Tensor src,
int64_t
start_idx
=
0
;
if
(
random_start
)
{
// TODO: GET RANDOM INTEGER
start_idx
=
rand
()
%
src
.
size
(
0
);
}
out_data
[
out_start
]
=
src_start
+
start_idx
;
...
...
@@ -56,5 +49,6 @@ torch::Tensor fps_cpu(torch::Tensor src,
src_start
=
src_end
,
out_start
=
out_end
;
}
return
out
;
}
csrc/cpu/fps_cpu.h
View file @
4a61d70f
...
...
@@ -2,6 +2,5 @@
#include <torch/extension.h>
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
optional
<
torch
::
Tensor
>
optional_ptr
,
double
ratio
,
torch
::
Tensor
fps_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
bool
random_start
);
csrc/cpu/graclus_cpu.cpp
View file @
4a61d70f
...
...
@@ -2,57 +2,42 @@
#include "utils.h"
torch
::
Tensor
graclus_cpu
(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_weight
,
int64_t
num_nodes
)
{
torch
::
Tensor
graclus_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_weight
)
{
CHECK_CPU
(
row
);
CHECK_CPU
(
row
ptr
);
CHECK_CPU
(
col
);
CHECK_INPUT
(
row
.
dim
()
==
1
&&
col
.
dim
()
==
1
&&
row
.
numel
()
==
col
.
numel
()
);
CHECK_INPUT
(
row
ptr
.
dim
()
==
1
&&
col
.
dim
()
==
1
);
if
(
optional_weight
.
has_value
())
{
CHECK_CPU
(
optional_weight
.
value
());
CHECK_INPUT
(
optional_weight
.
value
().
dim
()
==
1
);
CHECK_INPUT
(
optional_weight
.
value
().
numel
()
==
col
.
numel
());
}
auto
mask
=
row
!=
col
;
row
=
row
.
masked_select
(
mask
),
col
=
col
.
masked_select
(
mask
);
if
(
optional_weight
.
has_value
())
optional_weight
=
optional_weight
.
value
().
masked_select
(
mask
);
auto
perm
=
torch
::
randperm
(
row
.
size
(
0
),
row
.
options
());
row
=
row
.
index_select
(
0
,
perm
);
col
=
col
.
index_select
(
0
,
perm
);
if
(
optional_weight
.
has_value
())
optional_weight
=
optional_weight
.
value
().
index_select
(
0
,
perm
);
std
::
tie
(
row
,
perm
)
=
row
.
sort
();
col
=
col
.
index_select
(
0
,
perm
);
if
(
optional_weight
.
has_value
())
optional_weight
=
optional_weight
.
value
().
index_select
(
0
,
perm
);
auto
rowptr
=
torch
::
zeros
(
num_nodes
,
row
.
options
());
rowptr
=
rowptr
.
scatter_add_
(
0
,
row
,
torch
::
ones_like
(
row
)).
cumsum
(
0
);
rowptr
=
torch
::
cat
({
torch
::
zeros
(
1
,
row
.
options
()),
rowptr
},
0
);
perm
=
torch
::
randperm
(
num_nodes
,
row
.
options
());
auto
out
=
torch
::
full
(
num_nodes
,
-
1
,
row
.
options
());
int64_t
num_nodes
=
rowptr
.
numel
()
-
1
;
auto
out
=
torch
::
full
(
num_nodes
,
-
1
,
rowptr
.
options
());
auto
node_perm
=
torch
::
randperm
(
num_nodes
,
rowptr
.
options
());
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
auto
perm_data
=
perm
.
data_ptr
<
int64_t
>
();
auto
node_
perm_data
=
node_
perm
.
data_ptr
<
int64_t
>
();
auto
out_data
=
out
.
data_ptr
<
int64_t
>
();
if
(
!
optional_weight
.
has_value
())
{
for
(
auto
i
=
0
;
i
<
num_nodes
;
i
++
)
{
auto
u
=
perm_data
[
i
];
for
(
int64_t
n
=
0
;
n
<
num_nodes
;
n
++
)
{
auto
u
=
node_
perm_data
[
n
];
if
(
out_data
[
u
]
>=
0
)
continue
;
out_data
[
u
]
=
u
;
for
(
auto
j
=
rowptr_data
[
u
];
j
<
rowptr_data
[
u
+
1
];
j
++
)
{
auto
v
=
col_data
[
j
];
int64_t
row_start
=
rowptr_data
[
u
],
row_end
=
rowptr_data
[
u
+
1
];
auto
edge_perm
=
torch
::
randperm
(
row_end
-
row_start
,
rowptr
.
options
());
auto
edge_perm_data
=
edge_perm
.
data_ptr
<
int64_t
>
();
for
(
auto
e
=
0
;
e
<
row_end
-
row_start
;
e
++
)
{
auto
v
=
col_data
[
row_start
+
edge_perm_data
[
e
]];
if
(
out_data
[
v
]
>=
0
)
continue
;
...
...
@@ -67,8 +52,8 @@ torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
AT_DISPATCH_ALL_TYPES
(
weight
.
scalar_type
(),
"weighted_graclus"
,
[
&
]
{
auto
weight_data
=
weight
.
data_ptr
<
scalar_t
>
();
for
(
auto
i
=
0
;
i
<
num_nodes
;
i
++
)
{
auto
u
=
perm_data
[
i
];
for
(
auto
n
=
0
;
n
<
num_nodes
;
n
++
)
{
auto
u
=
node_
perm_data
[
n
];
if
(
out_data
[
u
]
>=
0
)
continue
;
...
...
@@ -76,15 +61,15 @@ torch::Tensor graclus_cpu(torch::Tensor row, torch::Tensor col,
auto
v_max
=
u
;
scalar_t
w_max
=
(
scalar_t
)
0.
;
for
(
auto
j
=
rowptr_data
[
u
];
j
<
rowptr_data
[
u
+
1
];
j
++
)
{
auto
v
=
col_data
[
j
];
for
(
auto
e
=
rowptr_data
[
u
];
e
<
rowptr_data
[
u
+
1
];
e
++
)
{
auto
v
=
col_data
[
e
];
if
(
out_data
[
v
]
>=
0
)
continue
;
if
(
weight_data
[
j
]
>=
w_max
)
{
if
(
weight_data
[
e
]
>=
w_max
)
{
v_max
=
v
;
w_max
=
weight_data
[
j
];
w_max
=
weight_data
[
e
];
}
}
...
...
csrc/cpu/graclus_cpu.h
View file @
4a61d70f
...
...
@@ -2,6 +2,5 @@
#include <torch/extension.h>
torch
::
Tensor
graclus_cpu
(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_weight
,
int64_t
num_nodes
);
torch
::
Tensor
graclus_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_weight
);
csrc/cuda/fps_cuda.cu
0 → 100644
View file @
4a61d70f
#include "fps_cuda.h"
#include "utils.cuh"
// Distance from row `idx` of `x` to every row of `x`: subtracts x[idx] from x
// and takes the 2-norm along dimension 1.
// NOTE(review): assumes `x` is 2-D (one point per row) — confirm with callers.
inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
  const auto offsets = x - x[idx];
  return offsets.norm(2, 1);
}
/// CUDA entry point for farthest point sampling (host-side setup).
///
/// @param src           CUDA tensor; viewed as (src.size(0), -1) and made
///                      contiguous.
/// @param ptr           1-D CUDA tensor read as int64. Consecutive differences
///                      (`ptr[i + 1] - ptr[i]`) give the per-segment sizes;
///                      presumably these are the per-example offsets into
///                      `src` — verify against the Python caller.
/// @param ratio         Sampling ratio, must satisfy 0 < ratio < 1.
/// @param random_start  If true, a per-segment start index is drawn uniformly
///                      at random; otherwise the start index is 0.
/// @return A tensor with `ptr`'s options holding
///         sum_i ceil(deg_i * ratio) elements.
///
/// NOTE(review): no kernel launch is visible in this block, so `out` is
/// returned uninitialized and `start` / the raw data pointers are unused.
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
                       bool random_start) {
  CHECK_CUDA(src);
  CHECK_CUDA(ptr);
  CHECK_INPUT(ptr.dim() == 1);
  AT_ASSERTM(ratio > 0 && ratio < 1, "Invalid input");

  src = src.view({src.size(0), -1}).contiguous();
  ptr = ptr.contiguous();
  auto batch_size = ptr.size(0) - 1;

  // Per-segment element counts: ptr[1:] - ptr[:-1].
  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);

  // Output offsets: exclusive prefix sum of ceil(deg * ratio).
  auto out_ptr = deg.toType(torch::kFloat) * static_cast<float>(ratio);
  out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
  // Was `torch.zeros` (Python syntax), which does not compile in C++.
  out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0);

  torch::Tensor start;
  if (random_start) {
    start = at::rand(batch_size, src.options());
    start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
  } else {
    start = torch::zeros(batch_size, ptr.options());
  }

  // `out_ptr` lives on the GPU, so its data pointer must not be dereferenced
  // from host code; `item()` copies the scalar back to the host first.
  const int64_t num_out = out_ptr[-1].item<int64_t>();
  auto out = torch::empty(num_out, ptr.options());

  auto ptr_data = ptr.data_ptr<int64_t>();
  auto out_ptr_data = out_ptr.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  return out;
}
csrc/cuda/fps_cuda.h
0 → 100644
View file @
4a61d70f
#pragma once
#include <torch/extension.h>
// CUDA implementation of farthest point sampling (defined in fps_cuda.cu).
//
// The definition checks that `src` and `ptr` are CUDA tensors, that `ptr` is
// one-dimensional, and that 0 < ratio < 1. `random_start` selects between a
// randomly drawn start index per segment and a start index of 0.
torch
::
Tensor
fps_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
bool
random_start
);
csrc/fps.cpp
View file @
4a61d70f
...
...
@@ -11,17 +11,16 @@
PyMODINIT_FUNC
PyInit__fps
(
void
)
{
return
NULL
;
}
#endif
torch
::
Tensor
fps
(
torch
::
Tensor
src
,
torch
::
optional
<
torch
::
Tensor
>
optional_ptr
,
double
ratio
,
torch
::
Tensor
fps
(
torch
::
Tensor
src
,
torch
::
Tensor
ptr
,
double
ratio
,
bool
random_start
)
{
if
(
src
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
return
fps_cuda
(
src
,
optional_
ptr
,
ratio
,
random_start
);
return
fps_cuda
(
src
,
ptr
,
ratio
,
random_start
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
fps_cpu
(
src
,
optional_
ptr
,
ratio
,
random_start
);
return
fps_cpu
(
src
,
ptr
,
ratio
,
random_start
);
}
}
...
...
csrc/graclus.cpp
View file @
4a61d70f
...
...
@@ -11,17 +11,16 @@
PyMODINIT_FUNC
PyInit__graclus
(
void
)
{
return
NULL
;
}
#endif
torch
::
Tensor
graclus
(
torch
::
Tensor
row
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_weight
,
int64_t
num_nodes
)
{
if
(
row
.
device
().
is_cuda
())
{
torch
::
Tensor
graclus
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_weight
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
return
graclus_cuda
(
row
,
col
,
optional_weight
,
num_nodes
);
return
graclus_cuda
(
row
ptr
,
col
,
optional_weight
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
graclus_cpu
(
row
,
col
,
optional_weight
,
num_nodes
);
return
graclus_cpu
(
row
ptr
,
col
,
optional_weight
);
}
}
...
...
setup.py
View file @
4a61d70f
...
...
@@ -26,6 +26,7 @@ def get_extensions():
define_macros
+=
[(
'WITH_CUDA'
,
None
)]
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
+=
[
'-arch=sm_35'
,
'--expt-relaxed-constexpr'
]
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
extensions_dir
=
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'csrc'
)
...
...
torch_cluster/fps.py
View file @
4a61d70f
...
...
@@ -33,9 +33,8 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
index = fps(src, batch, ratio=0.5)
"""
ptr
:
Optional
[
torch
.
Tensor
]
=
None
if
batch
is
not
None
:
assert
src
.
size
(
0
)
==
batch
.
size
(
0
)
assert
src
.
size
(
0
)
==
batch
.
numel
(
)
batch_size
=
int
(
batch
.
max
())
+
1
deg
=
src
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
...
...
@@ -43,5 +42,7 @@ def fps(src: torch.Tensor, batch: Optional[torch.Tensor] = None,
ptr
=
src
.
new_zeros
(
batch_size
+
1
,
dtype
=
torch
.
long
)
deg
.
cumsum
(
0
,
out
=
ptr
[
1
:])
else
:
ptr
=
torch
.
tensor
([
0
,
src
.
size
(
0
)],
device
=
src
.
device
)
return
torch
.
ops
.
torch_cluster
.
fps
(
src
,
ptr
,
ratio
,
random_start
)
torch_cluster/graclus.py
View file @
4a61d70f
...
...
@@ -32,4 +32,12 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
if
num_nodes
is
None
:
num_nodes
=
max
(
int
(
row
.
max
()),
int
(
col
.
max
()))
+
1
return
torch
.
ops
.
torch_cluster
.
graclus
(
row
,
col
,
weight
,
num_nodes
)
perm
=
torch
.
argsort
(
row
*
num_nodes
+
col
)
row
,
col
=
row
[
perm
],
col
[
perm
]
deg
=
row
.
new_zeros
(
num_nodes
)
deg
.
scatter_add_
(
0
,
row
,
torch
.
ones_like
(
row
))
rowptr
=
row
.
new_zeros
(
num_nodes
+
1
)
deg
.
cumsum
(
0
,
out
=
rowptr
[
1
:])
return
torch
.
ops
.
torch_cluster
.
graclus
(
rowptr
,
col
,
weight
)
torch_cluster/rw.py
View file @
4a61d70f
...
...
@@ -35,7 +35,7 @@ def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
num_nodes
=
max
(
int
(
row
.
max
()),
int
(
col
.
max
()))
+
1
if
coalesced
:
_
,
perm
=
torch
.
sort
(
row
*
num_nodes
+
col
)
perm
=
torch
.
arg
sort
(
row
*
num_nodes
+
col
)
row
,
col
=
row
[
perm
],
col
[
perm
]
deg
=
row
.
new_zeros
(
num_nodes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment