Unverified commit 684f66b7, authored by ndickson-nvidia, committed by GitHub

[Feature] Added exclude_self and output_batch to knn graph construction (Issues #4323 #4316) (#4389)

* * Added "exclude_self" and "output_batch" options to knn_graph and segmented_knn_graph
* Updated out-of-date comments on remove_edges and remove_self_loop, since they now preserve batch information

* * Changed defaults on new knn_graph and segmented_knn_graph function parameters, for compatibility; pytorch/test_geometry.py was failing

* * Added test to ensure dgl.remove_self_loop function correctly updates batch information

* * Added new knn_graph and segmented_knn_graph parameters to dgl.nn.KNNGraph and dgl.nn.SegmentedKNNGraph

* * Formatting

* * Oops, I missed the one in segmented_knn_graph when I fixed the similar thing in knn_graph

* * Fixed edge case handling when invalid k specified, since it still needs to be handled consistently for tests to pass
* Fixed context of batch info, since it must match the context of the input position data for remove_self_loop to succeed

* * Fixed batch info resulting from knn_graph when output_batch is true, for case of 3D input tensor, representing multiple segments

* * Added testing of new exclude_self and output_batch parameters on knn_graph and segmented_knn_graph, and their wrappers, KNNGraph and SegmentedKNNGraph, into the test_knn_cuda test

* * Added doc comments for new parameters

* * Added correct handling for uncommon case of k or more coincident points when excluding self edges in knn_graph and segmented_knn_graph
* Added test cases for more than k coincident points

* * Updated doc comments for output_batch parameters for clarity

* * Linter formatting fixes

* * Extracted out common function for test_knn_cpu and test_knn_cuda, to add the new test cases to test_knn_cpu

* * Rewording in doc comments

* * Removed output_batch parameter from knn_graph and segmented_knn_graph, in favour of always setting the batch information, except in knn_graph if x is a 2D tensor
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 1c9d2a03
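A minimal usage sketch of the exclude_self behaviour described above, assuming the knn_graph and segmented_knn_graph signatures added by this commit; the shapes, values, and printed results are illustrative only and are not part of the diff below:

import dgl
import torch as th

# two point sets of 6 random 3-D points each
x = th.randn(2, 6, 3)

# default: self loops are kept, so each node counts as one of its own 3 neighbours
g = dgl.knn_graph(x, 3)

# exclude_self=True: self loops are removed, so each node gets its 3 nearest other points
g_no_self = dgl.knn_graph(x, 3, exclude_self=True)

# batch information is now always set for 3D input (and for segmented_knn_graph),
# so the result behaves like a batched graph with one component per point set
print(g_no_self.batch_num_nodes())   # tensor([6, 6])
print(g_no_self.batch_num_edges())   # tensor([18, 18])

# the segmented variant takes per-set point counts instead of a 3D tensor
pts = th.randn(9, 3)
sg = dgl.segmented_knn_graph(pts, 3, [4, 5], exclude_self=True)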
@@ -13,7 +13,7 @@ def pairwise_squared_distance(x):
class KNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
point sets with the same number of points into a union of those graphs.
point sets with the same number of points into a batched union of those graphs.
The KNNGraph is implemented in the following steps:
@@ -63,7 +63,8 @@ class KNNGraph(nn.Module):
self.k = k
#pylint: disable=invalid-name
def forward(self, x, algorithm='bruteforce-blas', dist='euclidean'):
def forward(self, x, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""
Forward computation.
@@ -113,18 +114,23 @@ class KNNGraph(nn.Module):
:math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns
-------
DGLGraph
A DGLGraph without features.
"""
return knn_graph(x, self.k, algorithm=algorithm, dist=dist)
return knn_graph(x, self.k, algorithm=algorithm, dist=dist,
exclude_self=exclude_self)
class SegmentedKNNGraph(nn.Module):
r"""Layer that transforms one point set into a graph, or a batch of
point sets with different number of points into a union of those graphs.
point sets with different number of points into a batched union of those graphs.
If a batch of point sets is provided, then the point :math:`j` in the point
set :math:`i` is mapped to graph node ID:
@@ -171,7 +177,8 @@ class SegmentedKNNGraph(nn.Module):
self.k = k
#pylint: disable=invalid-name
def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean'):
def forward(self, x, segs, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""Forward computation.
Parameters
@@ -222,14 +229,19 @@ class SegmentedKNNGraph(nn.Module):
:math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns
-------
DGLGraph
A DGLGraph without features.
A batched DGLGraph without features.
"""
return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist)
return segmented_knn_graph(x, self.k, segs, algorithm=algorithm, dist=dist,
exclude_self=exclude_self)
class RadiusGraph(nn.Module):
@@ -41,6 +41,7 @@ from ..partition import partition_graph_with_halo
from ..partition import metis_partition
from .. import subgraph
from .. import function
from ..sampling.neighbor import sample_neighbors
# TO BE DEPRECATED
from .._deprecate.graph import DGLGraph as DGLGraphStale
@@ -97,7 +98,8 @@ def pairwise_squared_distance(x):
return x2s + F.swapaxes(x2s, -1, -2) - 2 * x @ F.swapaxes(x, -1, -2)
#pylint: disable=invalid-name
def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN)
and return.
@@ -110,8 +112,8 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
of each point are its k-nearest neighbors measured by the chosen distance.
If :attr:`x` is a 3D tensor, then each submatrix will be transformed
into a separate graph. DGL then composes the graphs into a large
graph of multiple connected components.
into a separate graph. DGL then composes the graphs into a large batched
graph of multiple (:math:`shape(x)[0]`) connected components.
See :doc:`the benchmark <../api/python/knn_benchmark>` for a complete benchmark result.
@@ -164,6 +166,10 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns
-------
@@ -205,26 +211,58 @@ def knn_graph(x, k, algorithm='bruteforce-blas', dist='euclidean'):
(tensor([0, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 6, 6, 7, 7]),
tensor([0, 1, 1, 2, 3, 0, 2, 3, 4, 5, 6, 7, 4, 6, 5, 7]))
"""
if exclude_self:
# add 1 to k, for the self edge, since it will be removed
k = k + 1
# check invalid k
if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
# check empty point set
if F.shape(x)[0] == 0:
x_size = tuple(F.shape(x))
if x_size[0] == 0:
raise DGLError("Find empty point set")
d = F.ndim(x)
x_seg = x_size[0] * [x_size[1]] if d == 3 else [x_size[0]]
if algorithm == 'bruteforce-blas':
return _knn_graph_blas(x, k, dist=dist)
result = _knn_graph_blas(x, k, dist=dist)
else:
if F.ndim(x) == 3:
x_size = tuple(F.shape(x))
if d == 3:
x = F.reshape(x, (x_size[0] * x_size[1], x_size[2]))
x_seg = x_size[0] * [x_size[1]]
else:
x_seg = [F.shape(x)[0]]
out = knn(k, x, x_seg, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
result = convert.graph((row, col))
if d == 3:
# set batch information if x is 3D
num_nodes = F.tensor(x_seg, dtype=F.int64).to(F.context(x))
result.set_batch_num_nodes(num_nodes)
# if any segment is too small for k, all algorithms reduce k for all segments
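    # e.g. a 3D input of shape (2, 4, 3) with k = 10 is clamped so that each node gets min(10, 4) = 4 edges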
clamped_k = min(k, np.min(x_seg))
result.set_batch_num_edges(clamped_k*num_nodes)
if exclude_self:
# remove_self_loop will update batch_num_edges as needed
result = remove_self_loop(result)
# If there were more than k(+1) coincident points, there may not have been self loops on
# all nodes, in which case there would still be one too many out edges on some nodes.
# However, if every node had a self edge, the common case, every node would still have the
# same degree as each other, so we can check that condition easily.
# The -1 is for the self edge removal.
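    # For example, with exclude_self a requested k of 3 becomes 4 here; if 5 or more points
    # coincide, a node's 4 nearest neighbours may all be other points (no self loop), so it
    # still has 4 in-edges after remove_self_loop instead of 3, and one of its zero-length
    # edges is dropped below.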
clamped_k = min(k, np.min(x_seg)) - 1
if result.num_edges() != clamped_k*result.num_nodes():
# edges on any nodes with too high degree should all be length zero,
# so pick an arbitrary one to remove from each such node
degrees = result.in_degrees()
node_indices = F.nonzero_1d(degrees > clamped_k)
edges_to_remove_graph = sample_neighbors(result, node_indices, 1, edge_dir='in')
edge_ids = edges_to_remove_graph.edata[EID]
result = remove_edges(result, edge_ids)
return result
def _knn_graph_blas(x, k, dist='euclidean'):
r"""Construct a graph from a set of points according to k-nearest-neighbor (KNN).
@@ -279,7 +317,8 @@ def _knn_graph_blas(x, k, dist='euclidean'):
return convert.graph((F.reshape(src, (-1,)), F.reshape(dst, (-1,))))
#pylint: disable=invalid-name
def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean'):
def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean',
exclude_self=False):
r"""Construct multiple graphs from multiple sets of points according to
k-nearest-neighbor (KNN) and return.
@@ -290,7 +329,7 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
function constructs a KNN graph for each point set, where the predecessors
of each point are its k-nearest neighbors measured by the Euclidean distance.
DGL then composes all KNN graphs
into a graph with multiple connected components.
into a batched graph with multiple (:math:`len(segs)`) connected components.
Parameters
----------
@@ -339,11 +378,15 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
* 'euclidean': Use Euclidean distance (L2 norm) :math:`\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}`.
* 'cosine': Use cosine distance.
(default: 'euclidean')
exclude_self : bool, optional
If True, the output graph will not contain self loop edges, and each node will not
be counted as one of its own k neighbors. If False, the output graph will contain
self loop edges, and a node will be counted as one of its own k neighbors.
Returns
-------
DGLGraph
The graph. The node IDs are in the same order as :attr:`x`.
The batched graph. The node IDs are in the same order as :attr:`x`.
Examples
--------
@@ -372,6 +415,10 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
(tensor([0, 0, 1, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6]),
tensor([0, 1, 0, 1, 2, 2, 3, 5, 4, 6, 3, 5, 4, 6]))
"""
if exclude_self:
# add 1 to k, for the self edge, since it will be removed
k = k + 1
# check invalid k
if k <= 0:
raise DGLError("Invalid k value. expect k > 0, got k = {}".format(k))
@@ -381,11 +428,38 @@ def segmented_knn_graph(x, k, segs, algorithm='bruteforce-blas', dist='euclidean
raise DGLError("Find empty point set")
if algorithm == 'bruteforce-blas':
return _segmented_knn_graph_blas(x, k, segs, dist=dist)
result = _segmented_knn_graph_blas(x, k, segs, dist=dist)
else:
out = knn(k, x, segs, algorithm=algorithm, dist=dist)
row, col = out[1], out[0]
return convert.graph((row, col))
result = convert.graph((row, col))
num_nodes = F.tensor(segs, dtype=F.int64).to(F.context(x))
result.set_batch_num_nodes(num_nodes)
# if any segment is too small for k, all algorithms reduce k for all segments
clamped_k = min(k, np.min(segs))
result.set_batch_num_edges(clamped_k*num_nodes)
if exclude_self:
# remove_self_loop will update batch_num_edges as needed
result = remove_self_loop(result)
# If there were more than k(+1) coincident points, there may not have been self loops on
# all nodes, in which case there would still be one too many out edges on some nodes.
# However, if every node had a self edge, the common case, every node would still have the
# same degree as each other, so we can check that condition easily.
# The -1 is for the self edge removal.
clamped_k = min(k, np.min(segs)) - 1
if result.num_edges() != clamped_k*result.num_nodes():
# edges on any nodes with too high degree should all be length zero,
# so pick an arbitrary one to remove from each such node
degrees = result.in_degrees()
node_indices = F.nonzero_1d(degrees > clamped_k)
edges_to_remove_graph = sample_neighbors(result, node_indices, 1, edge_dir='in')
edge_ids = edges_to_remove_graph.edata[EID]
result = remove_edges(result, edge_ids)
return result
def _segmented_knn_graph_blas(x, k, segs, dist='euclidean'):
r"""Construct multiple graphs from multiple sets of points according to
@@ -1638,11 +1712,7 @@ def remove_edges(g, eids, etype=None, store_ids=False):
Notes
-----
This function discards the batch information. Please use
:func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information.
This function preserves the batch information.
Examples
--------
@@ -1910,10 +1980,7 @@ def remove_self_loop(g, etype=None):
If a node has multiple self-loops, remove them all. Do nothing for nodes without
self-loops.
This function discards the batch information. Please use
:func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph
to maintain the information.
This function preserves the batch information.
Examples
---------
@@ -1797,6 +1797,16 @@ def test_remove_selfloop(idtype):
raise_error = True
assert raise_error
# batch information
g = dgl.graph(([0, 0, 0, 1, 3, 3, 4], [1, 0, 0, 2, 3, 4, 4]), idtype=idtype, device=F.ctx())
g.set_batch_num_nodes(F.tensor([3, 2], dtype=F.int64))
g.set_batch_num_edges(F.tensor([4, 3], dtype=F.int64))
g = dgl.remove_self_loop(g)
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 3
assert F.array_equal(g.batch_num_nodes(), F.tensor([3, 2], dtype=F.int64))
assert F.array_equal(g.batch_num_edges(), F.tensor([2, 1], dtype=F.int64))
@parametrize_idtype
def test_reorder_graph(idtype):
@@ -36,11 +36,8 @@ def test_fps_start_idx():
res = farthest_point_sampler(x, sample_points, start_idx=0)
assert th.any(res[:, 0] == 0)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
def test_knn_cpu(algorithm, dist):
x = th.randn(8, 3).to(F.cpu())
def _test_knn_common(device, algorithm, dist, exclude_self):
x = th.randn(8, 3).to(device)
kg = dgl.nn.KNNGraph(3)
if dist == 'euclidean':
d = th.cdist(x, x).to(F.cpu())
@@ -49,136 +46,126 @@ def test_knn_cpu(algorithm, dist):
tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())
def check_knn(g, x, start, end, k):
def check_knn(g, x, start, end, k, exclude_self, check_indices=True):
assert g.device == x.device
g = g.to(F.cpu())
for v in range(start, end):
src, _ = g.in_edges(v)
src = set(src.numpy())
i = v - start
src_ans = set(th.topk(d[start:end, start:end][i], k, largest=False)[1].numpy() + start)
assert src == src_ans
assert len(src) == k
if check_indices:
i = v - start
src_ans = set(th.topk(d[start:end, start:end][i], k + (1 if exclude_self else 0), largest=False)[1].numpy() + start)
if exclude_self:
# remove self
src_ans.remove(v)
assert src == src_ans
def check_batch(g, k, expected_batch_info):
assert F.array_equal(g.batch_num_nodes(), F.tensor(expected_batch_info))
assert F.array_equal(g.batch_num_edges(), k*F.tensor(expected_batch_info))
# check knn with 2d input
g = kg(x, algorithm, dist)
check_knn(g, x, 0, 8, 3)
g = kg(x, algorithm, dist, exclude_self)
check_knn(g, x, 0, 8, 3, exclude_self)
check_batch(g, 3, [8])
# check knn with 3d input
g = kg(x.view(2, 4, 3), algorithm, dist)
check_knn(g, x, 0, 4, 3)
check_knn(g, x, 4, 8, 3)
g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
check_knn(g, x, 0, 4, 3, exclude_self)
check_knn(g, x, 4, 8, 3, exclude_self)
check_batch(g, 3, [4, 4])
# check segmented knn
kg = dgl.nn.SegmentedKNNGraph(3)
g = kg(x, [3, 5], algorithm, dist)
check_knn(g, x, 0, 3, 3)
check_knn(g, x, 3, 8, 3)
# there are only 2 edges per node possible when exclude_self with 3 nodes in the segment
# and this test case isn't supposed to warn, so limit it when exclude_self is True
adjusted_k = 3 - (1 if exclude_self else 0)
kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
g = kg(x, [3, 5], algorithm, dist, exclude_self)
check_knn(g, x, 0, 3, adjusted_k, exclude_self)
check_knn(g, x, 3, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [3, 5])
# check k > num_points
kg = dgl.nn.KNNGraph(10)
with pytest.warns(DGLWarning):
g = kg(x, algorithm, dist)
check_knn(g, x, 0, 8, 8)
g = kg(x, algorithm, dist, exclude_self)
# there are only 7 edges per node possible when exclude_self with 8 nodes total
adjusted_k = 8 - (1 if exclude_self else 0)
check_knn(g, x, 0, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [8])
with pytest.warns(DGLWarning):
g = kg(x.view(2, 4, 3), algorithm, dist)
check_knn(g, x, 0, 4, 4)
check_knn(g, x, 4, 8, 4)
g = kg(x.view(2, 4, 3), algorithm, dist, exclude_self)
# there are only 3 edges per node possible when exclude_self with 4 nodes per segment
adjusted_k = 4 - (1 if exclude_self else 0)
check_knn(g, x, 0, 4, adjusted_k, exclude_self)
check_knn(g, x, 4, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [4, 4])
kg = dgl.nn.SegmentedKNNGraph(5)
with pytest.warns(DGLWarning):
g = kg(x, [3, 5], algorithm, dist)
check_knn(g, x, 0, 3, 3)
check_knn(g, x, 3, 8, 3)
g = kg(x, [3, 5], algorithm, dist, exclude_self)
# there are only 2 edges per node possible when exclude_self in the segment with
# only 3 nodes, and the current implementation reduces k for all segments
# in that case
adjusted_k = 3 - (1 if exclude_self else 0)
check_knn(g, x, 0, 3, adjusted_k, exclude_self)
check_knn(g, x, 3, 8, adjusted_k, exclude_self)
check_batch(g, adjusted_k, [3, 5])
# check k == 0
kg = dgl.nn.KNNGraph(0)
# k == 0 is valid when exclude_self is True (it becomes 1 internally), but -1 is not, so check -1 instead in that case
adjusted_k = 0 - (1 if exclude_self else 0)
kg = dgl.nn.KNNGraph(adjusted_k)
with pytest.raises(DGLError):
g = kg(x, algorithm, dist)
kg = dgl.nn.SegmentedKNNGraph(0)
g = kg(x, algorithm, dist, exclude_self)
kg = dgl.nn.SegmentedKNNGraph(adjusted_k)
with pytest.raises(DGLError):
g = kg(x, [3, 5], algorithm, dist)
g = kg(x, [3, 5], algorithm, dist, exclude_self)
# check empty
x_empty = th.tensor([])
kg = dgl.nn.KNNGraph(3)
with pytest.raises(DGLError):
g = kg(x_empty, algorithm, dist)
g = kg(x_empty, algorithm, dist, exclude_self)
kg = dgl.nn.SegmentedKNNGraph(3)
with pytest.raises(DGLError):
g = kg(x_empty, [3, 5], algorithm, dist)
g = kg(x_empty, [3, 5], algorithm, dist, exclude_self)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
def test_knn_cuda(algorithm, dist):
if not th.cuda.is_available():
return
x = th.randn(8, 3).to(F.cuda())
# check all coincident points
x = th.zeros((20, 3)).to(device)
kg = dgl.nn.KNNGraph(3)
if dist == 'euclidean':
d = th.cdist(x, x).to(F.cpu())
else:
x = x + th.randn(1).item()
tmp_x = x / (1e-5 + F.sqrt(F.sum(x * x, dim=1, keepdims=True)))
d = 1 - F.matmul(tmp_x, tmp_x.T).to(F.cpu())
def check_knn(g, x, start, end, k):
assert g.device == x.device
g = g.to(F.cpu())
for v in range(start, end):
src, _ = g.in_edges(v)
src = set(src.numpy())
i = v - start
src_ans = set(th.topk(d[start:end, start:end][i], k, largest=False)[1].numpy() + start)
assert src == src_ans
# check knn with 2d input
g = kg(x, algorithm, dist)
check_knn(g, x, 0, 8, 3)
# check knn with 3d input
g = kg(x.view(2, 4, 3), algorithm, dist)
check_knn(g, x, 0, 4, 3)
check_knn(g, x, 4, 8, 3)
g = kg(x, algorithm, dist, exclude_self)
# different algorithms may break the tie differently, so don't check the indices
check_knn(g, x, 0, 20, 3, exclude_self, False)
check_batch(g, 3, [20])
# check segmented knn
# check all coincident points
kg = dgl.nn.SegmentedKNNGraph(3)
g = kg(x, [3, 5], algorithm, dist)
check_knn(g, x, 0, 3, 3)
check_knn(g, x, 3, 8, 3)
# check k > num_points
kg = dgl.nn.KNNGraph(10)
with pytest.warns(DGLWarning):
g = kg(x, algorithm, dist)
check_knn(g, x, 0, 8, 8)
g = kg(x, [4, 7, 5, 4], algorithm, dist, exclude_self)
# different algorithms may break the tie differently, so don't check the indices
check_knn(g, x, 0, 4, 3, exclude_self, False)
check_knn(g, x, 4, 11, 3, exclude_self, False)
check_knn(g, x, 11, 16, 3, exclude_self, False)
check_knn(g, x, 16, 20, 3, exclude_self, False)
check_batch(g, 3, [4, 7, 5, 4])
with pytest.warns(DGLWarning):
g = kg(x.view(2, 4, 3), algorithm, dist)
check_knn(g, x, 0, 4, 4)
check_knn(g, x, 4, 8, 4)
kg = dgl.nn.SegmentedKNNGraph(5)
with pytest.warns(DGLWarning):
g = kg(x, [3, 5], algorithm, dist)
check_knn(g, x, 0, 3, 3)
check_knn(g, x, 3, 8, 3)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'kd-tree'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
@pytest.mark.parametrize('exclude_self', [False, True])
def test_knn_cpu(algorithm, dist, exclude_self):
_test_knn_common(F.cpu(), algorithm, dist, exclude_self)
# check k == 0
kg = dgl.nn.KNNGraph(0)
with pytest.raises(DGLError):
g = kg(x, algorithm, dist)
kg = dgl.nn.SegmentedKNNGraph(0)
with pytest.raises(DGLError):
g = kg(x, [3, 5], algorithm, dist)
# check empty
x_empty = th.tensor([])
kg = dgl.nn.KNNGraph(3)
with pytest.raises(DGLError):
g = kg(x_empty, algorithm, dist)
kg = dgl.nn.SegmentedKNNGraph(3)
with pytest.raises(DGLError):
g = kg(x_empty, [3, 5], algorithm, dist)
@pytest.mark.parametrize('algorithm', ['bruteforce-blas', 'bruteforce', 'bruteforce-sharemem'])
@pytest.mark.parametrize('dist', ['euclidean', 'cosine'])
@pytest.mark.parametrize('exclude_self', [False, True])
def test_knn_cuda(algorithm, dist, exclude_self):
if not th.cuda.is_available():
return
_test_knn_common(F.cuda(), algorithm, dist, exclude_self)
@parametrize_idtype