Commit 66eef3b3 authored by rusty1s's avatar rusty1s
Browse files

removed undirected call to pytorch geometric

parent 50909651
import torch
from torch_geometric.utils import to_undirected
from sample_cuda import farthest_point_sampling, query_radius, query_knn
......@@ -62,8 +61,7 @@ def radius_query_edges(batch,
query_pos,
radius,
max_num_neighbors=128,
include_self=True,
undirected=False):
include_self=True):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda
......@@ -91,19 +89,13 @@ def radius_query_edges(batch,
return col
edge_index = torch.stack([row, col], dim=0)
if undirected:
return to_undirected(edge_index, query_pos.size(0))
return edge_index
def radius_graph(batch,
pos,
radius,
max_num_neighbors=128,
include_self=False,
undirected=False):
def radius_graph(batch, pos, radius, max_num_neighbors=128,
include_self=False):
return radius_query_edges(batch, pos, batch, pos, radius,
max_num_neighbors, include_self, undirected)
max_num_neighbors, include_self)
def knn_query_edges(batch,
......@@ -111,8 +103,7 @@ def knn_query_edges(batch,
query_batch,
query_pos,
num_neighbors,
include_self=True,
undirected=False):
include_self=True):
if not pos.is_cuda:
raise NotImplementedError
assert pos.is_cuda and batch.is_cuda
......@@ -140,11 +131,8 @@ def knn_query_edges(batch,
col = view[view != -1]
edge_index = torch.stack([row, col], dim=0)
if undirected:
return to_undirected(edge_index, query_pos.size(0))
return edge_index
def knn_graph(batch, pos, num_neighbors, include_self=False, undirected=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self,
undirected)
def knn_graph(batch, pos, num_neighbors, include_self=False):
return knn_query_edges(batch, pos, batch, pos, num_neighbors, include_self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment