Commit e1180216 authored by rusty1s's avatar rusty1s
Browse files

linting fixed

parent 751dd81d
import pytest
import torch
import numpy as np
from torch_geometric.data import Batch
from numpy.testing import assert_almost_equal
from capsules.utils.sample import sample_farthest, batch_slices, radius_query_edges
from torch_cluster.sample import (sample_farthest, batch_slices,
radius_query_edges)
from .utils import tensor, grad_dtypes, devices
@pytest.mark.parametrize('device', devices)
def test_batch_slices(device):
# test sample case for correctness
batch = tensor([0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
batch = tensor(
[0] * 100 + [1] * 50 + [2] * 42, dtype=torch.long, device=device)
slices, sizes = batch_slices(batch, sizes=True)
slices, sizes = slices.cpu().tolist(), sizes.cpu().tolist()
......@@ -33,10 +34,11 @@ def test_fps(dtype):
batch = tensor(batch, dtype=torch.long, device='cuda')
pos = tensor(points + random_points, dtype=dtype, device='cuda')
idx = sample_farthest(batch, pos, num_sampled=4, index=True)
sample_farthest(batch, pos, num_sampled=4, index=True)
# needs update since isin is missing (sort indices, then compare?)
# assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'), False).all().cpu().item() == 1
# assert isin(idx, tensor([0, 1, 2, 3], dtype=torch.long, device='cuda'),
# False).all().cpu().item() == 1
# test variable number of points for each element in a batch
batch = [0] * 100 + [1] * 50
......@@ -67,7 +69,13 @@ def test_radius_edges(dtype):
pos = tensor(points, dtype=dtype, device='cuda')
query_pos = tensor(query_points, dtype=dtype, device='cuda')
edge_index = radius_query_edges(batch, pos, query_batch, query_pos, radius=radius, max_num_neighbors=128)
edge_index = radius_query_edges(
batch,
pos,
query_batch,
query_pos,
radius=radius,
max_num_neighbors=128)
row, col = edge_index
dist = torch.norm(pos[col] - query_pos[row], p=2, dim=1)
assert (dist <= radius).all().item()
import torch
from torch_scatter import scatter_add, scatter_max
from torch_geometric.utils import to_undirected
from torch_geometric.data import Batch
from torch_sparse import coalesce
from sample_cuda import farthest_point_sampling, query_radius, query_knn
......@@ -11,7 +8,7 @@ def batch_slices(batch, sizes=False, include_ends=True):
"""
Calculates size, start and end indices for each element in a batch.
"""
size = scatter_add(torch.ones_like(batch), batch)
size = torch.scatter_add_(torch.ones_like(batch), batch)
cumsum = torch.cumsum(size, dim=0)
starts = cumsum - size
ends = cumsum - 1
......@@ -26,10 +23,10 @@ def batch_slices(batch, sizes=False, include_ends=True):
def sample_farthest(batch, pos, num_sampled, random_start=False, index=False):
"""
Samples a specified number of points for each element in a batch using farthest iterative point sampling and returns
a mask (or indices) for the sampled points.
If there are less than num_sampled points in a point cloud all points are returned.
"""Samples a specified number of points for each element in a batch using
farthest iterative point sampling and returns a mask (or indices) for the
sampled points. If there are less than num_sampled points in a point cloud
all points are returned.
"""
if not pos.is_cuda or not batch.is_cuda:
raise NotImplementedError
......@@ -67,9 +64,10 @@ def radius_query_edges(batch,
undirected=False):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous(
) and query_pos.is_contiguous() and query_batch.is_contiguous()
assert pos.is_cuda and batch.is_cuda
assert query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous()
assert query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
......@@ -115,9 +113,10 @@ def knn_query_edges(batch,
undirected=False):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda and query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous(
) and query_pos.is_contiguous() and query_batch.is_contiguous()
assert pos.is_cuda and batch.is_cuda
assert query_pos.is_cuda and query_batch.is_cuda
assert pos.is_contiguous() and batch.is_contiguous()
assert query_pos.is_contiguous() and query_batch.is_contiguous()
slices, sizes = batch_slices(batch, sizes=True)
batch_size = batch.max().item() + 1
......@@ -147,5 +146,3 @@ def knn_query_edges(batch,
def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self,
undirected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment