Unverified Commit 14bffe97 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[NN] Renaming NearestNeighborGraph to KNNGraph (#802)

* initial commit

* second commit

* another commit

* change docstring

* migrating to dgl.nn

* fixes

* docs

* lint

* multiple fixes

* doc

* renaming nearest neighbor graph
parent dc19cd56
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NearestNeighborGraph, EdgeConv
from dgl.nn.pytorch import KNNGraph, EdgeConv
class Model(nn.Module):
def __init__(self, k, feature_dims, emb_dims, output_classes, input_dims=3,
dropout_prob=0.5):
super(Model, self).__init__()
self.nng = NearestNeighborGraph(k)
self.nng = KNNGraph(k)
self.conv = nn.ModuleList()
self.num_layers = len(feature_dims)
......
"""Modules that transforms between graphs and between graph and tensors."""
import torch.nn as nn
from ...transform import nearest_neighbor_graph, segmented_nearest_neighbor_graph
from ...transform import knn_graph, segmented_knn_graph
def pairwise_squared_distance(x):
'''
......@@ -11,7 +11,7 @@ def pairwise_squared_distance(x):
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)
class NearestNeighborGraph(nn.Module):
class KNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs.
......@@ -28,7 +28,7 @@ class NearestNeighborGraph(nn.Module):
The number of neighbors
"""
def __init__(self, k):
super(NearestNeighborGraph, self).__init__()
super(KNNGraph, self).__init__()
self.k = k
#pylint: disable=invalid-name
......@@ -46,10 +46,10 @@ class NearestNeighborGraph(nn.Module):
-------
A DGLGraph with no features.
"""
return nearest_neighbor_graph(x, self.k)
return knn_graph(x, self.k)
class SegmentedNearestNeighborGraph(nn.Module):
class SegmentedKNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs.
......@@ -80,7 +80,7 @@ class SegmentedNearestNeighborGraph(nn.Module):
- A DGLGraph with no features.
"""
def __init__(self, k):
super(SegmentedNearestNeighborGraph, self).__init__()
super(SegmentedKNNGraph, self).__init__()
self.k = k
#pylint: disable=invalid-name
......@@ -100,4 +100,4 @@ class SegmentedNearestNeighborGraph(nn.Module):
-------
A DGLGraph with no features.
"""
return segmented_nearest_neighbor_graph(x, self.k, segs)
return segmented_knn_graph(x, self.k, segs)
......@@ -10,7 +10,7 @@ from .batched_graph import BatchedDGLGraph, unbatch
__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max', 'nearest_neighbor_graph', 'segmented_nearest_neighbor_graph']
'laplacian_lambda_max', 'knn_graph', 'segmented_knn_graph']
def pairwise_squared_distance(x):
......@@ -23,7 +23,7 @@ def pairwise_squared_distance(x):
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
#pylint: disable=invalid-name
def nearest_neighbor_graph(x, k):
def knn_graph(x, k):
"""Transforms the given point set to a directed graph, whose coordinates
are given as a matrix. The predecessors of each point are its k-nearest
neighbors.
......@@ -69,7 +69,7 @@ def nearest_neighbor_graph(x, k):
return g
#pylint: disable=invalid-name
def segmented_nearest_neighbor_graph(x, k, segs):
def segmented_knn_graph(x, k, segs):
"""Transforms the given point set to a directed graph, whose coordinates
are given as a matrix. The predecessors of each point are its k-nearest
neighbors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment