Commit 66eef3b3 authored by rusty1s's avatar rusty1s
Browse files

removed undirected call to pytorch geometric

parent 50909651
import torch import torch
from torch_geometric.utils import to_undirected
from sample_cuda import farthest_point_sampling, query_radius, query_knn from sample_cuda import farthest_point_sampling, query_radius, query_knn
...@@ -62,8 +61,7 @@ def radius_query_edges(batch, ...@@ -62,8 +61,7 @@ def radius_query_edges(batch,
query_pos, query_pos,
radius, radius,
max_num_neighbors=128, max_num_neighbors=128,
include_self=True, include_self=True):
undirected=False):
if not pos.is_cuda: if not pos.is_cuda:
raise NotImplementedError raise NotImplementedError
assert pos.is_cuda and batch.is_cuda assert pos.is_cuda and batch.is_cuda
...@@ -91,19 +89,13 @@ def radius_query_edges(batch, ...@@ -91,19 +89,13 @@ def radius_query_edges(batch,
return col return col
edge_index = torch.stack([row, col], dim=0) edge_index = torch.stack([row, col], dim=0)
if undirected:
return to_undirected(edge_index, query_pos.size(0))
return edge_index return edge_index
def radius_graph(batch, def radius_graph(batch, pos, radius, max_num_neighbors=128,
pos, include_self=False):
radius,
max_num_neighbors=128,
include_self=False,
undirected=False):
return radius_query_edges(batch, pos, batch, pos, radius, return radius_query_edges(batch, pos, batch, pos, radius,
max_num_neighbors, include_self, undirected) max_num_neighbors, include_self)
def knn_query_edges(batch, def knn_query_edges(batch,
...@@ -111,8 +103,7 @@ def knn_query_edges(batch, ...@@ -111,8 +103,7 @@ def knn_query_edges(batch,
query_batch, query_batch,
query_pos, query_pos,
num_neighbors, num_neighbors,
include_self=True, include_self=True):
undirected=False):
if not pos.is_cuda: if not pos.is_cuda:
raise NotImplementedError raise NotImplementedError
assert pos.is_cuda and batch.is_cuda assert pos.is_cuda and batch.is_cuda
...@@ -140,11 +131,8 @@ def knn_query_edges(batch, ...@@ -140,11 +131,8 @@ def knn_query_edges(batch,
col = view[view != -1] col = view[view != -1]
edge_index = torch.stack([row, col], dim=0) edge_index = torch.stack([row, col], dim=0)
if undirected:
return to_undirected(edge_index, query_pos.size(0))
return edge_index return edge_index
def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False): def knn_graph(batch, pos, num_neighbors, include_self=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self, return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self)
undirected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment