Unverified Commit 14bffe97 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[NN] Renaming NearestNeighborGraph to KNNGraph (#802)

* initial commit

* second commit

* another commit

* change docstring

* migrating to dgl.nn

* fixes

* docs

* lint

* multiple fixes

* doc

* renaming nearest neighbor graph
parent dc19cd56
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.nn.pytorch import NearestNeighborGraph, EdgeConv from dgl.nn.pytorch import KNNGraph, EdgeConv
class Model(nn.Module): class Model(nn.Module):
def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3, def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3,
dropout_prob=0.5): dropout_prob=0.5):
super(Model, self).__init__() super(Model, self).__init__()
self.nng = NearestNeighborGraph(k) self.nng = KNNGraph(k)
self.conv = nn.ModuleList() self.conv = nn.ModuleList()
self.num_layers = len(feature_dims) self.num_layers = len(feature_dims)
......
"""Modules that transforms between graphs and between graph and tensors.""" """Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn import torch.nn as nn
from ...transform import nearest_neighbor_graph, segmented_nearest_neighbor_graph from ...transform import knn_graph, segmented_knn_graph
def pairwise_squared_distance(x): def pairwise_squared_distance(x):
''' '''
...@@ -11,7 +11,7 @@ def pairwise_squared_distance(x): ...@@ -11,7 +11,7 @@ def pairwise_squared_distance(x):
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2) return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)
class NearestNeighborGraph(nn.Module): class KNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of r"""Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs. point sets with the same number of points into a union of those graphs.
...@@ -28,7 +28,7 @@ class NearestNeighborGraph(nn.Module): ...@@ -28,7 +28,7 @@ class NearestNeighborGraph(nn.Module):
The number of neighbors The number of neighbors
""" """
def __init__(self, k): def __init__(self, k):
super(NearestNeighborGraph, self).__init__() super(KNNGraph, self).__init__()
self.k = k self.k = k
#pylint: disable=invalid-name #pylint: disable=invalid-name
...@@ -46,10 +46,10 @@ class NearestNeighborGraph(nn.Module): ...@@ -46,10 +46,10 @@ class NearestNeighborGraph(nn.Module):
------- -------
A DGLGraph with no features. A DGLGraph with no features.
""" """
return nearest_neighbor_graph(x, self.k) return knn_graph(x, self.k)
class SegmentedNearestNeighborGraph(nn.Module): class SegmentedKNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of r"""Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs. point sets with different number of points into a union of those graphs.
...@@ -80,7 +80,7 @@ class SegmentedNearestNeighborGraph(nn.Module): ...@@ -80,7 +80,7 @@ class SegmentedNearestNeighborGraph(nn.Module):
- A DGLGraph with no features. - A DGLGraph with no features.
""" """
def __init__(self, k): def __init__(self, k):
super(SegmentedNearestNeighborGraph, self).__init__() super(SegmentedKNNGraph, self).__init__()
self.k = k self.k = k
#pylint: disable=invalid-name #pylint: disable=invalid-name
...@@ -100,4 +100,4 @@ class SegmentedNearestNeighborGraph(nn.Module): ...@@ -100,4 +100,4 @@ class SegmentedNearestNeighborGraph(nn.Module):
------- -------
A DGLGraph with no features. A DGLGraph with no features.
""" """
return segmented_nearest_neighbor_graph(x, self.k, segs) return segmented_knn_graph(x, self.k, segs)
...@@ -10,7 +10,7 @@ from .batched_graph import BatchedDGLGraph, unbatch ...@@ -10,7 +10,7 @@ from .batched_graph import BatchedDGLGraph, unbatch
__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected', __all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max', 'nearest_neighbor_graph', 'segmented_nearest_neighbor_graph'] 'laplacian_lambda_max', 'knn_graph', 'segmented_knn_graph']
def pairwise_squared_distance(x): def pairwise_squared_distance(x):
...@@ -23,7 +23,7 @@ def pairwise_squared_distance(x): ...@@ -23,7 +23,7 @@ def pairwise_squared_distance(x):
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2) return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
#pylint: disable=invalid-name #pylint: disable=invalid-name
def nearest_neighbor_graph(x, k): def knn_graph(x, k):
"""Transforms the given point set to a directed graph, whose coordinates """Transforms the given point set to a directed graph, whose coordinates
are given as a matrix. The predecessors of each point are its k-nearest are given as a matrix. The predecessors of each point are its k-nearest
neighbors. neighbors.
...@@ -69,7 +69,7 @@ def nearest_neighbor_graph(x, k): ...@@ -69,7 +69,7 @@ def nearest_neighbor_graph(x, k):
return g return g
#pylint: disable=invalid-name #pylint: disable=invalid-name
def segmented_nearest_neighbor_graph(x, k, segs): def segmented_knn_graph(x, k, segs):
"""Transforms the given point set to a directed graph, whose coordinates """Transforms the given point set to a directed graph, whose coordinates
are given as a matrix. The predecessors of each point are its k-nearest are given as a matrix. The predecessors of each point are its k-nearest
neighbors. neighbors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment