Commit e1180216 authored by rusty1s's avatar rusty1s
Browse files

linting fixed

parent 751dd81d
import pytest import pytest
import torch import torch
import numpy as np import numpy as np
from torch_geometric.data import Batch
from numpy.testing import assert_almost_equal
from capsules.utils.sample import sample_farthest, batch_slices, radius_query_edges from torch_cluster.sample import (sample_farthest, batch_slices,
radius_query_edges)
from .utils import tensor, grad_dtypes, devices from .utils import tensor, grad_dtypes, devices
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_batch_slices(device): def test_batch_slices(device):
# test sample case for correctness # test sample case for correctness
batch = tensor([0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device) batch = tensor(
[0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
slices, sizes = batch_slices(batch, sizes=True) slices, sizes = batch_slices(batch, sizes=True)
slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist() slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist()
...@@ -33,10 +34,11 @@ def test_fps(dtype): ...@@ -33,10 +34,11 @@ def test_fps(dtype):
batch = tensor(batch, dtype=torch.long, device='cuda') batch = tensor(batch, dtype=torch.long, device='cuda')
pos = tensor(points + random_points, dtype=dtype, device='cuda') pos = tensor(points + random_points, dtype=dtype, device='cuda')
idx = sample_farthest(batch, pos, num_sampled=4, index=True) sample_farthest(batch, pos, num_sampled=4, index=True)
# needs update since isin is missing (sort indices, then compare?) # needs update since isin is missing (sort indices, then compare?)
# assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'), False).all().cpu().item() == 1 # assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'),
# False).all().cpu().item() == 1
# test variable number of points for each element in a batch # test variable number of points for each element in a batch
batch = [0] * 100 + [1] * 50 batch = [0] * 100 + [1] * 50
...@@ -67,7 +69,13 @@ def test_radius_edges(dtype): ...@@ -67,7 +69,13 @@ def test_radius_edges(dtype):
pos = tensor(points, dtype=dtype, device='cuda') pos = tensor(points, dtype=dtype, device='cuda')
query_pos = tensor(query_points, dtype=dtype, device='cuda') query_pos = tensor(query_points, dtype=dtype, device='cuda')
edge_index = radius_query_edges(batch, pos, query_batch, query_pos, radius=radius, max_num_neighbors=128) edge_index = radius_query_edges(
batch,
pos,
query_batch,
query_pos,
radius=radius,
max_num_neighbors=128)
row, col = edge_index row, col = edge_index
dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1) dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1)
assert (dist <= radius).all().item() assert (dist <= radius).all().item()
import torch import torch
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import to_undirected from torch_geometric.utils import to_undirected
from torch_geometric.data import Batch
from torch_sparse import coalesce
from sample_cuda import farthest_point_sampling, query_radius, query_knn from sample_cuda import farthest_point_sampling, query_radius, query_knn
...@@ -11,7 +8,7 @@ def batch_slices(batch, sizes=False, include_ends=True): ...@@ -11,7 +8,7 @@ def batch_slices(batch, sizes=False, include_ends=True):
""" """
Calculates size, start and end indices for each element in a batch. Calculates size, start and end indices for each element in a batch.
""" """
size = scatter_add(torch.ones_like(batch), batch) size = torch.scatter_add_(torch.ones_like(batch), batch)
cumsum = torch.cumsum(size, dim=0) cumsum = torch.cumsum(size, dim=0)
starts = cumsum - size starts = cumsum - size
ends = cumsum - 1 ends = cumsum - 1
...@@ -26,10 +23,10 @@ def batch_slices(batch, sizes=False, include_ends=True): ...@@ -26,10 +23,10 @@ def batch_slices(batch, sizes=False, include_ends=True):
def sample_farthest(batch, pos, num_sampled, random_start=False, index=False): def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
""" """Samples a specified number of points for each element in a batch using
Samples a specified number of points for each element in a batch using farthest iterative point sampling and returns farthest iterative point sampling and returns a mask (or indices) for the
a mask (or indices) for the sampled points. sampled points. If there are less than num_sampled points in a point cloud
If there are less than num_sampled points in a point cloud all points are returned. all points are returned.
""" """
if not pos.is_cuda or not batch.is_cuda: if not pos.is_cuda or not batch.is_cuda:
raise NotImplementedError raise NotImplementedError
...@@ -67,9 +64,10 @@ def radius_query_edges(batch, ...@@ -67,9 +64,10 @@ def radius_query_edges(batch,
undirected=False): undirected=False):
if not pos.is_cuda: if not pos.is_cuda:
raise NotImplementedError raise NotImplementedError
assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda assert pos.is_cuda and batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous( assert query_pos.is_cuda and query_batch.is_cuda
) and query_pos.is_contiguous() and query_batch.is_contiguous() assert pos.is_contiguous() and batch.is_contiguous()
assert query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True) slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1 batch_size = batch.max().item() + 1
...@@ -115,9 +113,10 @@ def knn_query_edges(batch, ...@@ -115,9 +113,10 @@ def knn_query_edges(batch,
undirected=False): undirected=False):
if not pos.is_cuda: if not pos.is_cuda:
raise NotImplementedError raise NotImplementedError
assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda assert pos.is_cuda and batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous( assert query_pos.is_cuda and query_batch.is_cuda
) and query_pos.is_contiguous() and query_batch.is_contiguous() assert pos.is_contiguous() and batch.is_contiguous()
assert query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True) slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1 batch_size = batch.max().item() + 1
...@@ -147,5 +146,3 @@ def knn_query_edges(batch, ...@@ -147,5 +146,3 @@ def knn_query_edges(batch,
def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False): def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self, return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self,
undirected) undirected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment