Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
61ef00a6
Commit
61ef00a6
authored
Nov 09, 2021
by
rusty1s
Browse files
fix knn and radius for unequal batch sizes
parent
10049daf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
30 deletions
+24
-30
setup.py
setup.py
+4
-0
torch_cluster/knn.py
torch_cluster/knn.py
+10
-15
torch_cluster/radius.py
torch_cluster/radius.py
+10
-15
No files found.
setup.py
View file @
61ef00a6
...
@@ -31,6 +31,8 @@ def get_extensions():
...
@@ -31,6 +31,8 @@ def get_extensions():
for
main
,
suffix
in
product
(
main_files
,
suffices
):
for
main
,
suffix
in
product
(
main_files
,
suffices
):
define_macros
=
[]
define_macros
=
[]
extra_compile_args
=
{
'cxx'
:
[
'-O2'
]}
extra_compile_args
=
{
'cxx'
:
[
'-O2'
]}
if
not
os
.
name
==
'nt'
:
# Not on Windows:
extra_compile_args
[
'cxx'
]
+=
[
'-Wno-sign-compare'
]
extra_link_args
=
[
'-s'
]
extra_link_args
=
[
'-s'
]
info
=
parallel_info
()
info
=
parallel_info
()
...
@@ -49,6 +51,8 @@ def get_extensions():
...
@@ -49,6 +51,8 @@ def get_extensions():
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
os
.
getenv
(
'NVCC_FLAGS'
,
''
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
=
[]
if
nvcc_flags
==
''
else
nvcc_flags
.
split
(
' '
)
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
,
'-O2'
]
nvcc_flags
+=
[
'--expt-relaxed-constexpr'
,
'-O2'
]
if
not
os
.
name
==
'nt'
:
# Not on Windows:
nvcc_flags
+=
[
'-Wno-sign-compare'
]
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
extra_compile_args
[
'nvcc'
]
=
nvcc_flags
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
name
=
main
.
split
(
os
.
sep
)[
-
1
][:
-
4
]
...
...
torch_cluster/knn.py
View file @
61ef00a6
...
@@ -50,27 +50,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
...
@@ -50,27 +50,22 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
batch_size
=
1
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
batch_size
=
int
(
batch_x
.
max
())
+
1
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_y
is
not
None
:
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
batch_size
=
max
(
batch_size
,
int
(
batch_y
.
max
())
+
1
)
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_size
>
1
:
assert
batch_x
is
not
None
assert
batch_y
is
not
None
arange
=
torch
.
arange
(
batch_size
+
1
,
device
=
x
.
device
)
ptr_x
=
torch
.
bucketize
(
arange
,
batch_x
)
ptr_y
=
torch
.
bucketize
(
arange
,
batch_y
)
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
ptr_x
,
ptr_y
,
k
,
cosine
,
return
torch
.
ops
.
torch_cluster
.
knn
(
x
,
y
,
ptr_x
,
ptr_y
,
k
,
cosine
,
num_workers
)
num_workers
)
...
...
torch_cluster/radius.py
View file @
61ef00a6
...
@@ -50,27 +50,22 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
...
@@ -50,27 +50,22 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
y
=
y
.
view
(
-
1
,
1
)
if
y
.
dim
()
==
1
else
y
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
x
,
y
=
x
.
contiguous
(),
y
.
contiguous
()
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
batch_size
=
1
if
batch_x
is
not
None
:
if
batch_x
is
not
None
:
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
assert
x
.
size
(
0
)
==
batch_x
.
numel
()
batch_size
=
int
(
batch_x
.
max
())
+
1
batch_size
=
int
(
batch_x
.
max
())
+
1
deg
=
x
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_x
,
torch
.
ones_like
(
batch_x
))
ptr_x
=
deg
.
new_zeros
(
batch_size
+
1
)
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_x
[
1
:])
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_y
is
not
None
:
if
batch_y
is
not
None
:
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
assert
y
.
size
(
0
)
==
batch_y
.
numel
()
batch_size
=
int
(
batch_y
.
max
())
+
1
batch_size
=
max
(
batch_size
,
int
(
batch_y
.
max
())
+
1
)
deg
=
y
.
new_zeros
(
batch_size
,
dtype
=
torch
.
long
)
deg
.
scatter_add_
(
0
,
batch_y
,
torch
.
ones_like
(
batch_y
))
ptr_y
=
deg
.
new_zeros
(
batch_size
+
1
)
ptr_x
:
Optional
[
torch
.
Tensor
]
=
None
torch
.
cumsum
(
deg
,
0
,
out
=
ptr_y
[
1
:])
ptr_y
:
Optional
[
torch
.
Tensor
]
=
None
if
batch_size
>
1
:
assert
batch_x
is
not
None
assert
batch_y
is
not
None
arange
=
torch
.
arange
(
batch_size
+
1
,
device
=
x
.
device
)
ptr_x
=
torch
.
bucketize
(
arange
,
batch_x
)
ptr_y
=
torch
.
bucketize
(
arange
,
batch_y
)
return
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
return
torch
.
ops
.
torch_cluster
.
radius
(
x
,
y
,
ptr_x
,
ptr_y
,
r
,
max_num_neighbors
,
num_workers
)
max_num_neighbors
,
num_workers
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment