Unverified Commit 972a9f13 authored by Tong He's avatar Tong He Committed by GitHub
Browse files

[Doc] Re-organize the code for dgl.geometry, and expose it in the doc (#2982)

* reorg and expose dgl.geometry

* fix lint

* fix test

* fix
parent e20d8953
.. _api-geometry:
dgl.geometry
=================================
.. automodule:: dgl.geometry
.. _api-geometry-farthest-point-sampler:
Farthest Point Sampler
----------------------
Farthest point sampling is a greedy algorithm that iteratively samples points
from point cloud data. It starts from a single randomly chosen point. In each iteration,
it picks, among the remaining points, the one that is farthest from the set of already sampled points.
.. autofunction:: farthest_point_sampler
.. _api-geometry-neighbor-matching:
Neighbor Matching
-----------------------------
Neighbor matching is an important module in the Graclus clustering algorithm.
.. autofunction:: neighbor_matching
...@@ -41,6 +41,7 @@ Welcome to Deep Graph Library Tutorials and Documentation ...@@ -41,6 +41,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
api/python/dgl.DGLGraph api/python/dgl.DGLGraph
api/python/dgl.distributed api/python/dgl.distributed
api/python/dgl.function api/python/dgl.function
api/python/dgl.geometry
api/python/nn api/python/nn
api/python/nn.functional api/python/nn.functional
api/python/dgl.ops api/python/dgl.ops
......
...@@ -5,7 +5,7 @@ from torch.autograd import Variable ...@@ -5,7 +5,7 @@ from torch.autograd import Variable
import numpy as np import numpy as np
import dgl import dgl
import dgl.function as fn import dgl.function as fn
from dgl.geometry.pytorch import FarthestPointSampler from dgl.geometry.pytorch import farthest_point_sampler
''' '''
Part of the code are adapted from Part of the code are adapted from
...@@ -167,7 +167,7 @@ class SAModule(nn.Module): ...@@ -167,7 +167,7 @@ class SAModule(nn.Module):
super(SAModule, self).__init__() super(SAModule, self).__init__()
self.group_all = group_all self.group_all = group_all
if not group_all: if not group_all:
self.fps = FarthestPointSampler(npoints) self.npoints = npoints
self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor) self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
self.message = RelativePositionMessage(n_neighbor) self.message = RelativePositionMessage(n_neighbor)
self.conv = PointNetConv(mlp_sizes, batch_size) self.conv = PointNetConv(mlp_sizes, batch_size)
...@@ -177,7 +177,7 @@ class SAModule(nn.Module): ...@@ -177,7 +177,7 @@ class SAModule(nn.Module):
if self.group_all: if self.group_all:
return self.conv.group_all(pos, feat) return self.conv.group_all(pos, feat)
centroids = self.fps(pos) centroids = farthest_point_sampler(pos, self.npoints)
g = self.frnn_graph(pos, centroids, feat) g = self.frnn_graph(pos, centroids, feat)
g.update_all(self.message, self.conv) g.update_all(self.message, self.conv)
...@@ -197,7 +197,7 @@ class SAMSGModule(nn.Module): ...@@ -197,7 +197,7 @@ class SAMSGModule(nn.Module):
self.batch_size = batch_size self.batch_size = batch_size
self.group_size = len(radius_list) self.group_size = len(radius_list)
self.fps = FarthestPointSampler(npoints) self.npoints = npoints
self.frnn_graph_list = nn.ModuleList() self.frnn_graph_list = nn.ModuleList()
self.message_list = nn.ModuleList() self.message_list = nn.ModuleList()
self.conv_list = nn.ModuleList() self.conv_list = nn.ModuleList()
...@@ -208,7 +208,7 @@ class SAMSGModule(nn.Module): ...@@ -208,7 +208,7 @@ class SAMSGModule(nn.Module):
self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size)) self.conv_list.append(PointNetConv(mlp_sizes_list[i], batch_size))
def forward(self, pos, feat): def forward(self, pos, feat):
centroids = self.fps(pos) centroids = farthest_point_sampler(pos, self.npoints)
feat_res_list = [] feat_res_list = []
for i in range(self.group_size): for i in range(self.group_size):
......
"""Package for geometry common components.""" """The ``dgl.geometry`` package contains geometry operations:
import importlib
import sys
from ..backend import backend_name
* Farthest point sampling for point cloud sampling
def _load_backend(mod_name): * Neighbor matching module for graclus pooling
mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__]
for api, obj in mod.__dict__.items():
setattr(thismod, api, obj)
_load_backend(backend_name) .. note::
This package is experimental and the interfaces may be subject
to changes in future releases.
"""
from .fps import *
from .edge_coarsening import *
...@@ -6,7 +6,7 @@ from .. import backend as F ...@@ -6,7 +6,7 @@ from .. import backend as F
from .. import ndarray as nd from .. import ndarray as nd
def farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result): def _farthest_point_sampler(data, batch_size, sample_points, dist, start_idx, result):
r"""Farthest Point Sampler r"""Farthest Point Sampler
Parameters Parameters
......
"""Edge coarsening procedure used in Metis and Graclus, for pytorch""" """Edge coarsening procedure used in Metis and Graclus, for pytorch"""
# pylint: disable=no-member, invalid-name, W0613 # pylint: disable=no-member, invalid-name, W0613
import dgl from .. import remove_self_loop
import torch as th from .capi import _neighbor_matching
from ..capi import _neighbor_matching
__all__ = ['neighbor_matching'] __all__ = ['neighbor_matching']
class NeighborMatchingFn(th.autograd.Function):
    r"""
    Description
    -----------
    Autograd function wrapping the ``_neighbor_matching`` call.
    """

    @staticmethod
    def forward(ctx, gidx, num_nodes, e_weights, relabel_idx):
        r"""
        Description
        -----------
        Forward every argument except ``ctx`` to ``_neighbor_matching`` and
        return whatever it returns.
        """
        matching = _neighbor_matching(gidx, num_nodes, e_weights, relabel_idx)
        return matching

    @staticmethod
    def backward(ctx):
        r"""
        Description
        -----------
        No gradient is computed; this always returns ``None``.
        """
        return None
def neighbor_matching(graph, e_weights=None, relabel_idx=True): def neighbor_matching(graph, e_weights=None, relabel_idx=True):
r""" r"""
Description Description
...@@ -63,14 +36,25 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True): ...@@ -63,14 +36,25 @@ def neighbor_matching(graph, e_weights=None, relabel_idx=True):
relabel_idx : bool, optional relabel_idx : bool, optional
If true, relabel resulting node labels to have consecutive node ids. If true, relabel resulting node labels to have consecutive node ids.
default: :obj:`True` default: :obj:`True`
Examples
--------
The following example uses PyTorch backend.
>>> import torch, dgl
>>> from dgl.geometry import neighbor_matching
>>>
>>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
>>> neighbor_matching(g)
tensor([0, 1, 1])
""" """
assert graph.is_homogeneous, \ assert graph.is_homogeneous, \
"The graph used in graph node matching must be homogeneous" "The graph used in graph node matching must be homogeneous"
if e_weights is not None: if e_weights is not None:
graph.edata['e_weights'] = e_weights graph.edata['e_weights'] = e_weights
graph = dgl.remove_self_loop(graph) graph = remove_self_loop(graph)
e_weights = graph.edata['e_weights'] e_weights = graph.edata['e_weights']
graph.edata.pop('e_weights') graph.edata.pop('e_weights')
else: else:
graph = dgl.remove_self_loop(graph) graph = remove_self_loop(graph)
return NeighborMatchingFn.apply(graph._graph, graph.num_nodes(), e_weights, relabel_idx) return _neighbor_matching(graph._graph, graph.num_nodes(), e_weights, relabel_idx)
"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name
from .. import backend as F
from ..base import DGLError
from .capi import _farthest_point_sampler
__all__ = ['farthest_point_sampler']
def farthest_point_sampler(pos, npoints, start_idx=None):
    """Farthest Point Sampler without the need to compute all pairs of distance.

    In each batch, the algorithm starts with the sample index specified by ``start_idx``.
    Then for each point, we maintain the minimum to-sample distance.
    Finally, we pick the point with the maximum such distance.
    This process will be repeated for ``npoints`` - 1 times.

    Parameters
    ----------
    pos : tensor
        The positional tensor of shape (B, N, C)
    npoints : int
        The number of points to sample in each batch.
    start_idx : int, optional
        If given, appoint the index of the starting point,
        otherwise randomly select a point as the start point.
        (default: None)

    Returns
    -------
    tensor of shape (B, npoints)
        The sampled indices in each batch.

    Raises
    ------
    DGLError
        If ``start_idx`` is given and does not satisfy ``0 <= start_idx < N``.

    Examples
    --------
    The following example uses PyTorch backend.

    >>> import torch
    >>> from dgl.geometry import farthest_point_sampler
    >>> x = torch.rand((2, 10, 3))
    >>> point_idx = farthest_point_sampler(x, 2)
    >>> print(point_idx)
        tensor([[5, 6],
                [7, 8]])
    """
    ctx = F.context(pos)
    B, N, C = pos.shape
    # Flatten to (B * N, C) before handing the data to ``_farthest_point_sampler``.
    pos = pos.reshape(-1, C)
    # One entry per point; passed to the C API as ``dist``.
    dist = F.zeros((B * N), dtype=pos.dtype, ctx=ctx)
    if start_idx is None:
        # The upper bound is exclusive in the framework ``randint`` routines this
        # backend call replaces, so ``high=N`` is needed for every index in
        # [0, N) to be a possible start point (``N - 1`` would never pick the
        # last point and leaves an empty range when N == 1).
        start_idx = F.randint(shape=(B, ), dtype=F.int64, ctx=ctx, low=0, high=N)
    else:
        if not 0 <= start_idx < N:
            raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
                N, start_idx))
        # Use the same user-provided start index for every batch.
        start_idx = F.full_1d((B, ), start_idx, dtype=F.int64, ctx=ctx)
    # Output buffer; the return value of the C API call is not used, so the
    # sampled indices are presumably written into ``result`` in place.
    result = F.zeros((npoints * B), dtype=F.int64, ctx=ctx)
    _farthest_point_sampler(pos, B, npoints, dist, start_idx, result)
    return result.reshape(B, npoints)
"""Package for mxnet-specific Geometry modules."""
from .fps import *
from .edge_coarsening import *
"""Edge coarsening procedure used in Metis and Graclus, for mxnet"""
# pylint: disable=no-member, invalid-name, W0235
import dgl
import mxnet as mx
from ..capi import _neighbor_matching
__all__ = ['neighbor_matching']
class NeighborMatchingFn(mx.autograd.Function):
    r"""
    Description
    -----------
    Autograd function wrapping the ``_neighbor_matching`` call.

    The call arguments are captured at construction time and replayed by
    :meth:`forward`.
    """

    def __init__(self, gidx, num_nodes, e_weights, relabel_idx):
        super().__init__()
        # Stored verbatim; forwarded to ``_neighbor_matching`` in ``forward``.
        self.gidx = gidx
        self.num_nodes = num_nodes
        self.e_weights = e_weights
        self.relabel_idx = relabel_idx

    def forward(self):
        r"""
        Description
        -----------
        Call ``_neighbor_matching`` with the stored arguments and return its result.
        """
        matching = _neighbor_matching(
            self.gidx, self.num_nodes, self.e_weights, self.relabel_idx)
        return matching

    def backward(self):
        r"""
        Description
        -----------
        No gradient is computed; this always returns ``None``.
        """
        return None
def neighbor_matching(graph, e_weights=None, relabel_idx=True):
    r"""
    Description
    -----------
    The neighbor matching procedure of edge coarsening in
    `Metis <http://cacs.usc.edu/education/cs653/Karypis-METIS-SIAMJSC98.pdf>`__
    and
    `Graclus <https://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf>`__
    for homogeneous graph coarsening. This procedure keeps picking an unmarked
    vertex and matching it with one of its unmarked neighbors (that maximizes its
    edge weight) until no match can be done.

    If no edge weight is given, this procedure will randomly pick neighbor for each
    vertex.

    The GPU implementation is based on `A GPU Algorithm for Greedy Graph Matching
    <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__

    NOTE: The input graph must be bi-directed (undirected) graph. Call :obj:`dgl.to_bidirected`
    if you are not sure your graph is bi-directed.

    Parameters
    ----------
    graph : DGLGraph
        The input homogeneous graph.
    e_weights : mxnet.NDArray, optional
        The edge weight tensor holding non-negative scalar weight for each edge.
        default: :obj:`None`
    relabel_idx : bool, optional
        If true, relabel resulting node labels to have consecutive node ids.
        default: :obj:`True`

    Returns
    -------
    The result of ``NeighborMatchingFn`` applied to the graph with self loops
    removed.
    """
    assert graph.is_homogeneous, \
        "The graph used in graph node matching must be homogeneous"
    if e_weights is not None:
        # Attach the weights as an edge feature so that they go through
        # ``remove_self_loop`` together with the edges, then read them back
        # from the resulting graph.
        # NOTE(review): the feature is popped only from the graph returned by
        # ``remove_self_loop``; confirm whether the caller's graph is left
        # with an 'e_weights' entry.
        graph.edata['e_weights'] = e_weights
        graph = dgl.remove_self_loop(graph)
        e_weights = graph.edata['e_weights']
        graph.edata.pop('e_weights')
    else:
        graph = dgl.remove_self_loop(graph)
    func = NeighborMatchingFn(graph._graph, graph.num_nodes(), e_weights, relabel_idx)
    return func()
"""Farthest Point Sampler for mxnet Geometry package"""
#pylint: disable=no-member, invalid-name
from mxnet import nd
from mxnet.gluon import nn
import numpy as np
from ...base import DGLError
from ..capi import farthest_point_sampler
class FarthestPointSampler(nn.Block):
    """Farthest Point Sampler

    In each batch, the algorithm starts with the sample index specified by ``start_idx``.
    Then for each point, we maintain the minimum to-sample distance.
    Finally, we pick the point with the maximum such distance.
    This process will be repeated for ``npoints`` - 1 times.

    Parameters
    ----------
    npoints : int
        The number of points to sample in each batch.
    """
    def __init__(self, npoints):
        super(FarthestPointSampler, self).__init__()
        self.npoints = npoints

    def forward(self, pos, start_idx=None):
        r"""Memory allocation and sampling

        Parameters
        ----------
        pos : tensor
            The positional tensor of shape (B, N, C)
        start_idx : int, optional
            If given, appoint the index of the starting point,
            otherwise randomly select a point as the start point.
            (default: None)

        Returns
        -------
        tensor of shape (B, self.npoints)
            The sampled indices in each batch.

        Raises
        ------
        DGLError
            If ``start_idx`` is given and does not satisfy ``0 <= start_idx < N``.
        """
        ctx = pos.context
        B, N, C = pos.shape
        # Flatten to (B * N, C) before handing the data to the C API.
        pos = pos.reshape(-1, C)
        dist = nd.zeros((B * N), dtype=pos.dtype, ctx=ctx)
        # ``np.int`` was only an alias of the builtin ``int`` and has been
        # removed from recent NumPy releases, so the builtin is used directly.
        if start_idx is None:
            # ``randint`` samples from the half-open interval [low, high), so
            # ``high`` must be N for the last point to be a possible start
            # (N - 1 also leaves an empty range when N == 1).
            start_idx = nd.random.randint(0, N, (B, ), dtype=int, ctx=ctx)
        else:
            if not 0 <= start_idx < N:
                raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
                    N, start_idx))
            # Use the same user-provided start index for every batch.
            start_idx = nd.full((B, ), start_idx, dtype=int, ctx=ctx)
        # Output buffer; the return value of the C API call is not used, so the
        # sampled indices are presumably written into ``result`` in place.
        result = nd.zeros((self.npoints * B), dtype=int, ctx=ctx)
        farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
        return result.reshape(B, self.npoints)
"""Package for pytorch-specific Geometry modules."""
from .fps import *
from .edge_coarsening import *
"""Farthest Point Sampler for pytorch Geometry package"""
#pylint: disable=no-member, invalid-name
import torch as th
from torch import nn
from ...base import DGLError
from ..capi import farthest_point_sampler
class FarthestPointSampler(nn.Module):
    """Farthest Point Sampler without the need to compute all pairs of distance.

    In each batch, the algorithm starts with the sample index specified by ``start_idx``.
    Then for each point, we maintain the minimum to-sample distance.
    Finally, we pick the point with the maximum such distance.
    This process will be repeated for ``npoints`` - 1 times.

    Parameters
    ----------
    npoints : int
        The number of points to sample in each batch.
    """
    def __init__(self, npoints):
        super(FarthestPointSampler, self).__init__()
        self.npoints = npoints

    def forward(self, pos, start_idx=None):
        r"""Memory allocation and sampling

        Parameters
        ----------
        pos : tensor
            The positional tensor of shape (B, N, C)
        start_idx : int, optional
            If given, appoint the index of the starting point,
            otherwise randomly select a point as the start point.
            (default: None)

        Returns
        -------
        tensor of shape (B, self.npoints)
            The sampled indices in each batch.

        Raises
        ------
        DGLError
            If ``start_idx`` is given and does not satisfy ``0 <= start_idx < N``.
        """
        device = pos.device
        B, N, C = pos.shape
        # Flatten to (B * N, C) before handing the data to the C API.
        pos = pos.reshape(-1, C)
        dist = th.zeros((B * N), dtype=pos.dtype, device=device)
        if start_idx is None:
            # ``th.randint`` excludes ``high``, so it must be N for the last
            # point to be a possible start (N - 1 also fails when N == 1).
            start_idx = th.randint(0, N, (B, ), dtype=th.long, device=device)
        else:
            if not 0 <= start_idx < N:
                raise DGLError("Invalid start_idx, expected 0 <= start_idx < {}, got {}".format(
                    N, start_idx))
            # Use the same user-provided start index for every batch.
            start_idx = th.full((B, ), start_idx, dtype=th.long, device=device)
        # Output buffer; the return value of the C API call is not used, so the
        # sampled indices are presumably written into ``result`` in place.
        result = th.zeros((self.npoints * B), dtype=th.long, device=device)
        farthest_point_sampler(pos, B, self.npoints, dist, start_idx, result)
        return result.reshape(B, self.npoints)
import mxnet as mx import mxnet as mx
from dgl.geometry.mxnet import FarthestPointSampler from dgl.geometry import farthest_point_sampler
import backend as F import backend as F
import numpy as np import numpy as np
...@@ -12,8 +12,7 @@ def test_fps(): ...@@ -12,8 +12,7 @@ def test_fps():
ctx = F.ctx() ctx = F.ctx()
if F.gpu_ctx(): if F.gpu_ctx():
x = x.as_in_context(ctx) x = x.as_in_context(ctx)
fps = FarthestPointSampler(sample_points) res = farthest_point_sampler(x, sample_points)
res = fps(x)
assert res.shape[0] == batch_size assert res.shape[0] == batch_size
assert res.shape[1] == sample_points assert res.shape[1] == sample_points
assert res.sum() > 0 assert res.sum() > 0
......
...@@ -6,8 +6,7 @@ import pytest ...@@ -6,8 +6,7 @@ import pytest
import torch as th import torch as th
from dgl import DGLError from dgl import DGLError
from dgl.base import DGLWarning from dgl.base import DGLWarning
from dgl.geometry.pytorch import FarthestPointSampler from dgl.geometry import neighbor_matching, farthest_point_sampler
from dgl.geometry import neighbor_matching
from test_utils import parametrize_dtype from test_utils import parametrize_dtype
from test_utils.graph_cases import get_cases from test_utils.graph_cases import get_cases
...@@ -20,8 +19,7 @@ def test_fps(): ...@@ -20,8 +19,7 @@ def test_fps():
ctx = F.ctx() ctx = F.ctx()
if F.gpu_ctx(): if F.gpu_ctx():
x = x.to(ctx) x = x.to(ctx)
fps = FarthestPointSampler(sample_points) res = farthest_point_sampler(x, sample_points)
res = fps(x)
assert res.shape[0] == batch_size assert res.shape[0] == batch_size
assert res.shape[1] == sample_points assert res.shape[1] == sample_points
assert res.sum() > 0 assert res.sum() > 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment