Unverified Commit 684f66b7 authored by ndickson-nvidia's avatar ndickson-nvidia Committed by GitHub
Browse files

[Feature] Added exclude_self and output_batch to knn graph construction...


[Feature] Added exclude_self and output_batch to knn graph construction (Issues #4323 #4316) (#4389)

* * Added "exclude_self" and "output_batch" options to knn_graph and segmented_knn_graph
* Updated out-of-date comments on remove_edges and remove_self_loop, since they now preserve batch information

* * Changed defaults on new knn_graph and segmented_knn_graph function parameters, for compatibility; pytorch/test_geometry.py was failing

* * Added test to ensure dgl.remove_self_loop function correctly updates batch information

* * Added new knn_graph and segmented_knn_graph parameters to dgl.nn.KNNGraph and dgl.nn.SegmentedKNNGraph

* * Formatting

* * Oops, I missed the one in segmented_knn_graph when I fixed the similar thing in knn_graph

* * Fixed edge case handling when invalid k specified, since it still needs to be handled consistently for tests to pass
* Fixed context of batch info, since it must match the context of the input position data for remove_self_loop to succeed

* * Fixed batch info resulting from knn_graph when output_batch is true, for case of 3D input tensor, representing multiple segments

* * Added testing of new exclude_self and output_batch parameters on knn_graph and segmented_knn_graph, and their wrappers, KNNGraph and SegmentedKNNGraph, into the test_knn_cuda test

* * Added doc comments for new parameters

* * Added correct handling for uncommon case of k or more coincident points when excluding self edges in knn_graph and segmented_knn_graph
* Added test cases for more than k coincident points

* * Updated doc comments for output_batch parameters for clarity

* * Linter formatting fixes

* * Extracted out common function for test_knn_cpu and test_knn_cuda, to add the new test cases to test_knn_cpu

* * Rewording in doc comments

* * Removed output_batch parameter from knn_graph and segmented_knn_graph, in favour of always setting the batch information, except in knn_graph if x is a 2D tensor
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 1c9d2a03
...@@ -13,7 +13,7 @@ def pairwise_squared_distance(x): ...@@ -13,7 +13,7 @@ def pairwise_squared_distance(x):
class KNNGraph(nn.Module): class KNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of r"""Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs. point sets with the same number of points into a batched union of those graphs.
The KNNGraph is implemented in the following steps: The KNNGraph is implemented in the following steps:
...@@ -63,7 +63,8 @@ class KNNGraph(nn.Module): ...@@ -63,7 +63,8 @@ class KNNGraph(nn.Module):
self.k = k self.k = k
#pylint: disable=invalid-name #pylint: disable=invalid-name
def forward(self, x, algorithm='bruteforce-blas', dist='euclidean'): def forward(self, x, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r""" r"""
Forward computation. Forward computation.
...@@ -113,18 +114,23 @@ class KNNGraph(nn.Module): ...@@ -113,18 +114,23 @@ class KNNGraph(nn.Module):
:math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`. :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance. * 'cosine': Use cosine distance.
(default: 'euclidean') (default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns Returns
------- -------
DGLGraph DGLGraph
A DGLGraph without features. A DGLGraph without features.
""" """
return knn_graph(x, self.k, algorithm=algorithm, dist=dist) return knn_graph(x, self.k, algorithm=algorithm, dist=dist,
exclude_self=exclude_self)
class SegmentedKNNGraph(nn.Module): class SegmentedKNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of r"""Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs. point sets with different number of points into a batched union of those graphs.
If a batch of point sets is provided, then the point :math:`j` in the point If a batch of point sets is provided, then the point :math:`j` in the point
set :math:`i` is mapped to graph node ID: set :math:`i` is mapped to graph node ID:
...@@ -171,7 +177,8 @@ class SegmentedKNNGraph(nn.Module): ...@@ -171,7 +177,8 @@ class SegmentedKNNGraph(nn.Module):
self.k = k self.k = k
#pylint: disable=invalid-name #pylint: disable=invalid-name
def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean'): def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""Forward computation. r"""Forward computation.
Parameters Parameters
...@@ -222,14 +229,19 @@ class SegmentedKNNGraph(nn.Module): ...@@ -222,14 +229,19 @@ class SegmentedKNNGraph(nn.Module):
:math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`. :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance. * 'cosine': Use cosine distance.
(default: 'euclidean') (default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns Returns
------- -------
DGLGraph DGLGraph
A DGLGraph without features. A batched DGLGraph without features.
""" """
return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist) return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist,
exclude_self=exclude_self)
class RadiusGraph(nn.Module): class RadiusGraph(nn.Module):
......
...@@ -41,6 +41,7 @@ from ..partition import partition_graph_with_halo ...@@ -41,6 +41,7 @@ from ..partition import partition_graph_with_halo
from ..partition import metis_partition from ..partition import metis_partition
from .. import subgraph from .. import subgraph
from .. import function from .. import function
from ..sampling.neighbor import sample_neighbors
# TO BE DEPRECATED # TO BE DEPRECATED
from .._deprecate.graph import DGLGraph as DGLGraphStale from .._deprecate.graph import DGLGraph as DGLGraphStale
...@@ -97,7 +98,8 @@ def pairwise_squared_distance(x): ...@@ -97,7 +98,8 @@ def pairwise_squared_distance(x):
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2) return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
#pylint: disable=invalid-name #pylint: disable=invalid-name
def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'): def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN) r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
and return. and return.
...@@ -110,8 +112,8 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'): ...@@ -110,8 +112,8 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
of each point are its k-nearest neighbors measured by the chosen distance. of each point are its k-nearest neighbors measured by the chosen distance.
If :attr:`x` is a 3D tensor, then each submatrix will be transformed If :attr:`x` is a 3D tensor, then each submatrix will be transformed
into a separate graph. DGL then composes the graphs into a large into a separate graph. DGL then composes the graphs into a large batched
graph of multiple connected components. graph of multiple (:math:`shape(x)[0]`) connected components.
See :doc:`the benchmark <../api/python/knn_benchmark>` for a complete benchmark result. See :doc:`the benchmark <../api/python/knn_benchmark>` for a complete benchmark result.
...@@ -164,6 +166,10 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'): ...@@ -164,6 +166,10 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`. * 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance. * 'cosine': Use cosine distance.
(default: 'euclidean') (default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns Returns
------- -------
...@@ -205,26 +211,58 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'): ...@@ -205,26 +211,58 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
(tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]), (tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),
tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7])) tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))
""" """
if exclude_self:
# add 1 to k, for the self edge, since it will be removed
k = k + 1
# check invalid k # check invalid k
if k <= 0: if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k)) raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
# check empty point set # check empty point set
if F.shape(x)[0] == 0: x_size = tuple(F.shape(x))
if x_size[0] == 0:
raise DGLError("Find empty point set") raise DGLError("Find empty point set")
d = F.ndim(x)
x_seg = x_size[0] * [x_size[1]] if d == 3 else [x_size[0]]
if algorithm == 'bruteforce-blas': if algorithm == 'bruteforce-blas':
return _knn_graph_blas(x, k, dist=dist) result = _knn_graph_blas(x, k, dist=dist)
else: else:
if F.ndim(x) == 3: if d == 3:
x_size = tuple(F.shape(x))
x = F.reshape(x, (x_size[0] * x_size[1], x_size[2])) x = F.reshape(x, (x_size[0] * x_size[1], x_size[2]))
x_seg = x_size[0] * [x_size[1]]
else:
x_seg = [F.shape(x)[0]]
out = knn(k, x, x_seg, algorithm=algorithm, dist=dist) out = knn(k, x, x_seg, algorithm=algorithm, dist=dist)
row, col = out[1], out[0] row, col = out[1], out[0]
return convert.graph((row, col)) result = convert.graph((row, col))
if d == 3:
# set batch information if x is 3D
num_nodes = F.tensor(x_seg, dtype=F.int64).to(F.context(x))
result.set_batch_num_nodes(num_nodes)
# if any segment is too small for k, all algorithms reduce k for all segments
clamped_k = min(k, np.min(x_seg))
result.set_batch_num_edges(clamped_k*num_nodes)
if exclude_self:
# remove_self_loop will update batch_num_edges as needed
result = remove_self_loop(result)
# If there were more than k(+1) coincident points, there may not have been self loops on
# all nodes, in which case there would still be one too many out edges on some nodes.
# However, if every node had a self edge, the common case, every node would still have the
# same degree as each other, so we can check that condition easily.
# The -1 is for the self edge removal.
clamped_k = min(k, np.min(x_seg)) - 1
if result.num_edges() != clamped_k*result.num_nodes():
# edges on any nodes with too high degree should all be length zero,
# so pick an arbitrary one to remove from each such node
degrees = result.in_degrees()
node_indices = F.nonzero_1d(degrees > clamped_k)
edges_to_remove_graph = sample_neighbors(result, node_indices, 1, edge_dir='in')
edge_ids = edges_to_remove_graph.edata[EID]
result = remove_edges(result, edge_ids)
return result
def _knn_graph_blas(x, k, dist='euclidean'): def _knn_graph_blas(x, k, dist='euclidean'):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN). r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN).
...@@ -279,7 +317,8 @@ def _knn_graph_blas(x, k, dist='euclidean'): ...@@ -279,7 +317,8 @@ def _knn_graph_blas(x, k, dist='euclidean'):
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,)))) return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
#pylint: disable=invalid-name #pylint: disable=invalid-name
def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean'): def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""Construct multiple graphs from multiple sets of points according to r"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) and return. k-nearest-neighbor (KNN) and return.
...@@ -290,7 +329,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -290,7 +329,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
function constructs a KNN graph for each point set, where the predecessors function constructs a KNN graph for each point set, where the predecessors
of each point are its k-nearest neighbors measured by the Euclidean distance. of each point are its k-nearest neighbors measured by the Euclidean distance.
DGL then composes all KNN graphs DGL then composes all KNN graphs
into a graph with multiple connected components. into a batched graph with multiple (:math:`len(segs)`) connected components.
Parameters Parameters
---------- ----------
...@@ -339,11 +378,15 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -339,11 +378,15 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`. * 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance. * 'cosine': Use cosine distance.
(default: 'euclidean') (default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns Returns
------- -------
DGLGraph DGLGraph
The graph. The node IDs are in the same order as :attr:`x`. The batched graph. The node IDs are in the same order as :attr:`x`.
Examples Examples
-------- --------
...@@ -372,6 +415,10 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -372,6 +415,10 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
(tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]), (tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6])) tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))
""" """
if exclude_self:
# add 1 to k, for the self edge, since it will be removed
k = k + 1
# check invalid k # check invalid k
if k <= 0: if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k)) raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
...@@ -381,11 +428,38 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean ...@@ -381,11 +428,38 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
raise DGLError("Find empty point set") raise DGLError("Find empty point set")
if algorithm == 'bruteforce-blas': if algorithm == 'bruteforce-blas':
return _segmented_knn_graph_blas(x, k, segs, dist=dist) result = _segmented_knn_graph_blas(x, k, segs, dist=dist)
else: else:
out = knn(k, x, segs, algorithm=algorithm, dist=dist) out = knn(k, x, segs, algorithm=algorithm, dist=dist)
row, col = out[1], out[0] row, col = out[1], out[0]
return convert.graph((row, col)) result = convert.graph((row, col))
num_nodes = F.tensor(segs, dtype=F.int64).to(F.context(x))
result.set_batch_num_nodes(num_nodes)
# if any segment is too small for k, all algorithms reduce k for all segments
clamped_k = min(k, np.min(segs))
result.set_batch_num_edges(clamped_k*num_nodes)
if exclude_self:
# remove_self_loop will update batch_num_edges as needed
result = remove_self_loop(result)
# If there were more than k(+1) coincident points, there may not have been self loops on
# all nodes, in which case there would still be one too many out edges on some nodes.
# However, if every node had a self edge, the common case, every node would still have the
# same degree as each other, so we can check that condition easily.
# The -1 is for the self edge removal.
clamped_k = min(k, np.min(segs)) - 1
if result.num_edges() != clamped_k*result.num_nodes():
# edges on any nodes with too high degree should all be length zero,
# so pick an arbitrary one to remove from each such node
degrees = result.in_degrees()
node_indices = F.nonzero_1d(degrees > clamped_k)
edges_to_remove_graph = sample_neighbors(result, node_indices, 1, edge_dir='in')
edge_ids = edges_to_remove_graph.edata[EID]
result = remove_edges(result, edge_ids)
return result
def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'): def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
r"""Construct multiple graphs from multiple sets of points according to r"""Construct multiple graphs from multiple sets of points according to
...@@ -1638,11 +1712,7 @@ def remove_edges(g, eids, etype=None, store_ids=False): ...@@ -1638,11 +1712,7 @@ def remove_edges(g, eids, etype=None, store_ids=False):
Notes Notes
----- -----
This function preserves the batch information.
This function discards the batch information. Please use
:func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information.
Examples Examples
-------- --------
...@@ -1910,10 +1980,7 @@ def remove_self_loop(g, etype=None): ...@@ -1910,10 +1980,7 @@ def remove_self_loop(g, etype=None):
If a node has multiple self-loops, remove them all. Do nothing for nodes without If a node has multiple self-loops, remove them all. Do nothing for nodes without
self-loops. self-loops.
This function discards the batch information. Please use This function preserves the batch information.
:func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information.
Examples Examples
--------- ---------
......
...@@ -1797,6 +1797,16 @@ def test_remove_selfloop(idtype): ...@@ -1797,6 +1797,16 @@ def test_remove_selfloop(idtype):
raise_error = True raise_error = True
assert raise_error assert raise_error
# batch information
g = dgl.graph(([0, 0, 0, 1, 3, 3, 4], [1, 0, 0, 2, 3, 4, 4]), idtype=idtype, device=F.ctx())
g.set_batch_num_nodes(F.tensor([3, 2], dtype=F.int64))
g.set_batch_num_edges(F.tensor([4, 3], dtype=F.int64))
g = dgl.remove_self_loop(g)
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 3
assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=F.int64))
assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], dtype=F.int64))
@parametrize_idtype @parametrize_idtype
def test_reorder_graph(idtype): def test_reorder_graph(idtype):
......
...@@ -36,11 +36,8 @@ def test_fps_start_idx(): ...@@ -36,11 +36,8 @@ def test_fps_start_idx():
res = farthest_point_sampler(x, sample_points, start_idx=0) res = farthest_point_sampler(x, sample_points, start_idx=0)
assert th.any(res[:, 0] == 0) assert th.any(res[:, 0] == 0)
def _test_knn_common(device, algorithm, dist, exclude_self):
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree']) x = th.randn(8, 3).to(device)
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
def test_knn_cpu(algorithm, dist):
x = th.randn(8, 3).to(F.cpu())
kg = dgl.nn.KNNGraph(3) kg = dgl.nn.KNNGraph(3)
if dist == 'euclidean': if dist == 'euclidean':
d = th.cdist(x, x).to(F.cpu()) d = th.cdist(x, x).to(F.cpu())
...@@ -49,136 +46,126 @@ def test_knn_cpu(algorithm, dist): ...@@ -49,136 +46,126 @@ def test_knn_cpu(algorithm, dist):
tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True))) tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu()) d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())
def check_knn(g, x, start, end, k): def check_knn(g, x, start, end, k, exclude_self, check_indices=True):
assert g.device == x.device assert g.device == x.device
g = g.to(F.cpu())
for v in range(start, end): for v in range(start, end):
src, _ = g.in_edges(v) src, _ = g.in_edges(v)
src = set(src.numpy()) src = set(src.numpy())
assert len(src) == k
if check_indices:
i = v - start i = v - start
src_ans = set(th.topk(d[start:end, start:end][i], k, largest=False)[1].numpy() + start) src_ans = set(th.topk(d[start:end, start:end][i], k + (1 if exclude_self else 0), largest=False)[1].numpy() + start)
if exclude_self:
# remove self
src_ans.remove(v)
assert src == src_ans assert src == src_ans
def check_batch(g, k, expected_batch_info):
assert F.array_equal(g.batch_num_nodes(), F.tensor(expected_batch_info))
assert F.array_equal(g.batch_num_edges(), k*F.tensor(expected_batch_info))
# check knn with 2d input # check knn with 2d input
g = kg(x, algorithm, dist) g = kg(x, algorithm, dist, exclude_self)
check_knn(g, x, 0, 8, 3) check_knn(g, x, 0, 8, 3, exclude_self)
check_batch(g, 3, [8])
# check knn with 3d input # check knn with 3d input
g = kg(x.view(2, 4, 3), algorithm, dist) g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
check_knn(g, x, 0, 4, 3) check_knn(g, x, 0, 4, 3, exclude_self)
check_knn(g, x, 4, 8, 3) check_knn(g, x, 4, 8, 3, exclude_self)
check_batch(g, 3, [4, 4])
# check segmented knn # check segmented knn
kg = dgl.nn.SegmentedKNNGraph(3) # there are only 2 edges per node possible when exclude_self with 3 nodes in the segment
g = kg(x, [3, 5], algorithm, dist) # and this test case isn't supposed to warn, so limit it when exclude_self is True
check_knn(g, x, 0, 3, 3) adjusted_k = 3 - (1 if exclude_self else 0)
check_knn(g, x, 3, 8, 3) kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
g = kg(x, [3, 5], algorithm, dist, exclude_self)
check_knn(g, x, 0, 3, adjusted_k, exclude_self)
check_knn(g, x, 3, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [3, 5])
# check k > num_points # check k > num_points
kg = dgl.nn.KNNGraph(10) kg = dgl.nn.KNNGraph(10)
with pytest.warns(DGLWarning): with pytest.warns(DGLWarning):
g = kg(x, algorithm, dist) g = kg(x, algorithm, dist, exclude_self)
check_knn(g, x, 0, 8, 8) # there are only 7 edges per node possible when exclude_self with 8 nodes total
adjusted_k = 8 - (1 if exclude_self else 0)
check_knn(g, x, 0, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [8])
with pytest.warns(DGLWarning): with pytest.warns(DGLWarning):
g = kg(x.view(2, 4, 3), algorithm, dist) g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
check_knn(g, x, 0, 4, 4) # there are only 3 edges per node possible when exclude_self with 4 nodes per segment
check_knn(g, x, 4, 8, 4) adjusted_k = 4 - (1 if exclude_self else 0)
check_knn(g, x, 0, 4, adjusted_k, exclude_self)
check_knn(g, x, 4, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [4, 4])
kg = dgl.nn.SegmentedKNNGraph(5) kg = dgl.nn.SegmentedKNNGraph(5)
with pytest.warns(DGLWarning): with pytest.warns(DGLWarning):
g = kg(x, [3, 5], algorithm, dist) g = kg(x, [3, 5], algorithm, dist, exclude_self)
check_knn(g, x, 0, 3, 3) # there are only 2 edges per node possible when exclude_self in the segment with
check_knn(g, x, 3, 8, 3) # only 3 nodes, and the current implementation reduces k for all segments
# in that case
adjusted_k = 3 - (1 if exclude_self else 0)
check_knn(g, x, 0, 3, adjusted_k, exclude_self)
check_knn(g, x, 3, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [3, 5])
# check k == 0 # check k == 0
kg = dgl.nn.KNNGraph(0) # that's valid for exclude_self, but -1 is not, so check -1 instead for exclude_self
adjusted_k = 0 - (1 if exclude_self else 0)
kg = dgl.nn.KNNGraph(adjusted_k)
with pytest.raises(DGLError): with pytest.raises(DGLError):
g = kg(x, algorithm, dist) g = kg(x, algorithm, dist, exclude_self)
kg = dgl.nn.SegmentedKNNGraph(0) kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
with pytest.raises(DGLError): with pytest.raises(DGLError):
g = kg(x, [3, 5], algorithm, dist) g = kg(x, [3, 5], algorithm, dist, exclude_self)
# check empty # check empty
x_empty = th.tensor([]) x_empty = th.tensor([])
kg = dgl.nn.KNNGraph(3) kg = dgl.nn.KNNGraph(3)
with pytest.raises(DGLError): with pytest.raises(DGLError):
g = kg(x_empty, algorithm, dist) g = kg(x_empty, algorithm, dist, exclude_self)
kg = dgl.nn.SegmentedKNNGraph(3) kg = dgl.nn.SegmentedKNNGraph(3)
with pytest.raises(DGLError): with pytest.raises(DGLError):
g = kg(x_empty, [3, 5], algorithm, dist) g = kg(x_empty, [3, 5], algorithm, dist, exclude_self)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem']) # check all coincident points
@pytest.mark.parametrize('dist', ['euclidean', 'cosine']) x = th.zeros((20, 3)).to(device)
def test_knn_cuda(algorithm, dist):
if not th.cuda.is_available():
return
x = th.randn(8, 3).to(F.cuda())
kg = dgl.nn.KNNGraph(3) kg = dgl.nn.KNNGraph(3)
if dist == 'euclidean': g = kg(x, algorithm, dist, exclude_self)
d = th.cdist(x, x).to(F.cpu()) # different algorithms may break the tie differently, so don't check the indices
else: check_knn(g, x, 0, 20, 3, exclude_self, False)
x = x + th.randn(1).item() check_batch(g, 3, [20])
tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())
def check_knn(g, x, start, end, k):
assert g.device == x.device
g = g.to(F.cpu())
for v in range(start, end):
src, _ = g.in_edges(v)
src = set(src.numpy())
i = v - start
src_ans = set(th.topk(d[start:end, start:end][i], k, largest=False)[1].numpy() + start)
assert src == src_ans
# check knn with 2d input
g = kg(x, algorithm, dist)
check_knn(g, x, 0, 8, 3)
# check knn with 3d input # check all coincident points
g = kg(x.view(2, 4, 3), algorithm, dist)
check_knn(g, x, 0, 4, 3)
check_knn(g, x, 4, 8, 3)
# check segmented knn
kg = dgl.nn.SegmentedKNNGraph(3) kg = dgl.nn.SegmentedKNNGraph(3)
g = kg(x, [3, 5], algorithm, dist) g = kg(x, [4, 7, 5, 4], algorithm, dist, exclude_self)
check_knn(g, x, 0, 3, 3) # different algorithms may break the tie differently, so don't check the indices
check_knn(g, x, 3, 8, 3) check_knn(g, x, 0, 4, 3, exclude_self, False)
check_knn(g, x, 4, 11, 3, exclude_self, False)
check_knn(g, x, 11, 16, 3, exclude_self, False)
check_knn(g, x, 16, 20, 3, exclude_self, False)
check_batch(g, 3, [4, 7, 5, 4])
# check k > num_points
kg = dgl.nn.KNNGraph(10)
with pytest.warns(DGLWarning):
g = kg(x, algorithm, dist)
check_knn(g, x, 0, 8, 8)
with pytest.warns(DGLWarning): @pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree'])
g = kg(x.view(2, 4, 3), algorithm, dist) @pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
check_knn(g, x, 0, 4, 4) @pytest.mark.parametrize('exclude_self', [False, True])
check_knn(g, x, 4, 8, 4) def test_knn_cpu(algorithm, dist, exclude_self):
_test_knn_common(F.cpu(), algorithm, dist, exclude_self)
kg = dgl.nn.SegmentedKNNGraph(5)
with pytest.warns(DGLWarning):
g = kg(x, [3, 5], algorithm, dist)
check_knn(g, x, 0, 3, 3)
check_knn(g, x, 3, 8, 3)
# check k == 0 @pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem'])
kg = dgl.nn.KNNGraph(0) @pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
with pytest.raises(DGLError): @pytest.mark.parametrize('exclude_self', [False, True])
g = kg(x, algorithm, dist) def test_knn_cuda(algorithm, dist, exclude_self):
kg = dgl.nn.SegmentedKNNGraph(0) if not th.cuda.is_available():
with pytest.raises(DGLError): return
g = kg(x, [3, 5], algorithm, dist) _test_knn_common(F.cuda(), algorithm, dist, exclude_self)
# check empty
x_empty = th.tensor([])
kg = dgl.nn.KNNGraph(3)
with pytest.raises(DGLError):
g = kg(x_empty, algorithm, dist)
kg = dgl.nn.SegmentedKNNGraph(3)
with pytest.raises(DGLError):
g = kg(x_empty, [3, 5], algorithm, dist)
@parametrize_idtype @parametrize_idtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment