OpenDAS / torch-cluster · Commits

Commit e1180216 (parent 751dd81d)
Authored Nov 15, 2018 by rusty1s

linting fixed
Showing 2 changed files with 28 additions and 23 deletions:

  test/test_sample.py      +15  -7
  torch_cluster/sample.py  +13 -16
test/test_sample.py (view file @ e1180216)

 import pytest
 import torch
 import numpy as np
-from torch_geometric.data import Batch
-from numpy.testing import assert_almost_equal
-from capsules.utils.sample import sample_farthest, batch_slices, radius_query_edges
+from torch_cluster.sample import (sample_farthest,
+                                  batch_slices,
+                                  radius_query_edges)
 from .utils import tensor, grad_dtypes, devices

 @pytest.mark.parametrize('device', devices)
 def test_batch_slices(device):
     # test sample case for correctness
-    batch = tensor([0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
+    batch = tensor(
+        [0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
     slices, sizes = batch_slices(batch, sizes=True)
     slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist()
...
@@ -33,10 +34,11 @@ def test_fps(dtype):
     batch = tensor(batch, dtype=torch.long, device='cuda')
     pos = tensor(points + random_points, dtype=dtype, device='cuda')
-    idx = sample_farthest(batch, pos, num_sampled=4, index=True)
+    sample_farthest(batch, pos, num_sampled=4, index=True)
     # needs update since isin is missing (sort indices, then compare?)
-    # assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'), False).all().cpu().item() == 1
+    # assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'),
+    #             False).all().cpu().item() == 1
     # test variable number of points for each element in a batch
     batch = [0] * 100 + [1] * 50
...
@@ -67,7 +69,13 @@ def test_radius_edges(dtype):
     pos = tensor(points, dtype=dtype, device='cuda')
     query_pos = tensor(query_points, dtype=dtype, device='cuda')
-    edge_index = radius_query_edges(batch, pos, query_batch, query_pos, radius=radius, max_num_neighbors=128)
+    edge_index = radius_query_edges(
+        batch,
+        pos,
+        query_batch,
+        query_pos,
+        radius=radius,
+        max_num_neighbors=128)
     row, col = edge_index
     dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1)
     assert (dist <= radius).all().item()
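Note: the disabled assertion in test_fps references an isin helper that the comment says is missing, and suggests sorting the indices and then comparing. A minimal sketch of that sort-then-compare idea on plain CPU tensors, with an illustrative idx value standing in for the real output of sample_farthest:

import torch

# Illustrative stand-in for the indices returned by sample_farthest.
idx = torch.tensor([2, 0, 3, 1], dtype=torch.long)

# Sort the sampled indices, then compare elementwise against the
# expected index set instead of testing membership with `isin`.
sorted_idx, _ = torch.sort(idx)
expected = torch.tensor([0, 1, 2, 3], dtype=torch.long)
assert torch.equal(sorted_idx, expected)

This checks set equality rather than membership, which is strictly stronger and avoids the missing helper entirely.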
torch_cluster/sample.py (view file @ e1180216)

 import torch
-from torch_scatter import scatter_add, scatter_max
 from torch_geometric.utils import to_undirected
-from torch_geometric.data import Batch
-from torch_sparse import coalesce
 from sample_cuda import farthest_point_sampling, query_radius, query_knn
...
@@ -11,7 +8,7 @@ def batch_slices(batch, sizes=False, include_ends=True):
     """
     Calculates size, start and end indices for each element in a batch.
     """
-    size = scatter_add(torch.ones_like(batch), batch)
+    size = torch.scatter_add_(torch.ones_like(batch), batch)
     cumsum = torch.cumsum(size, dim=0)
     starts = cumsum - size
     ends = cumsum - 1
...
@@ -26,10 +23,10 @@ def batch_slices(batch, sizes=False, include_ends=True):
 def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
     """
     Samples a specified number of points for each element in a batch using
     farthest iterative point sampling and returns
-    a mask (or indices) for the
-    sampled points. If there are less than num_sampled points in a point cloud
+    a mask (or indices) for the sampled points.
+    If there are less than num_sampled points in a point cloud
     all points are returned.
     """
     if not pos.is_cuda or not batch.is_cuda:
         raise NotImplementedError
...
@@ -67,9 +64,10 @@ def radius_query_edges(batch,
                        undirected=False):
     if not pos.is_cuda:
         raise NotImplementedError
-    assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
-    assert pos.is_contiguous() and batch.is_contiguous(
-    ) and query_pos.is_contiguous() and query_batch.is_contiguous()
+    assert pos.is_cuda and batch.is_cuda
+    assert query_pos.is_cuda and query_batch.is_cuda
+    assert pos.is_contiguous() and batch.is_contiguous()
+    assert query_pos.is_contiguous() and query_batch.is_contiguous()
     slices, sizes = batch_slices(batch, sizes=True)
     batch_size = batch.max().item() + 1
...
@@ -115,9 +113,10 @@ def knn_query_edges(batch,
                     undirected=False):
     if not pos.is_cuda:
         raise NotImplementedError
-    assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
-    assert pos.is_contiguous() and batch.is_contiguous(
-    ) and query_pos.is_contiguous() and query_batch.is_contiguous()
+    assert pos.is_cuda and batch.is_cuda
+    assert query_pos.is_cuda and query_batch.is_cuda
+    assert pos.is_contiguous() and batch.is_contiguous()
+    assert query_pos.is_contiguous() and query_batch.is_contiguous()
     slices, sizes = batch_slices(batch, sizes=True)
     batch_size = batch.max().item() + 1
...
@@ -147,5 +146,3 @@ def knn_query_edges(batch,
 def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
     return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self,
                            undirected)
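Note: the slice arithmetic in batch_slices (per-element sizes, then cumulative sums to get start and end offsets) can be reproduced with stock PyTorch ops; torch.bincount yields the same per-element counts as scatter-adding ones over the batch vector. A minimal sketch of the arithmetic only, not the module's API:

import torch

# Batch vector from the test: 100, 50 and 42 points per element.
batch = torch.tensor([0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long)

size = torch.bincount(batch)        # tensor([100,  50,  42])
cumsum = torch.cumsum(size, dim=0)  # tensor([100, 150, 192])
starts = cumsum - size              # tensor([  0, 100, 150])
ends = cumsum - 1                   # tensor([ 99, 149, 191])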
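Note: the sample_farthest docstring describes iterative farthest point sampling (repeatedly pick the point farthest from the set selected so far, returning all points when a cloud has fewer than num_sampled). A minimal single-cloud reference sketch of that algorithm in plain PyTorch; this is not the CUDA kernel (farthest_point_sampling) that the module actually wraps:

import torch

def fps_single(pos, num_sampled):
    # Greedy farthest point sampling for one point cloud of shape (N, D).
    n = pos.size(0)
    if n <= num_sampled:
        # Fewer points than requested: return all of them, as documented.
        return torch.arange(n, dtype=torch.long)
    idx = torch.zeros(num_sampled, dtype=torch.long)  # start at point 0
    min_dist = torch.full((n,), float('inf'))
    for i in range(1, num_sampled):
        # Distance of every point to the most recently selected point,
        # folded into the running distance-to-selected-set.
        d = torch.norm(pos - pos[idx[i - 1]], p=2, dim=1)
        min_dist = torch.min(min_dist, d)
        idx[i] = torch.argmax(min_dist)  # farthest remaining point
    return idx

pts = torch.rand(100, 3)
sampled = fps_single(pts, 4)  # indices of 4 well-spread points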