Commit 75216b3c authored by rusty1s's avatar rusty1s
Browse files

better knn/radius readme description [ci skip]

parent b75a19e8
...@@ -158,6 +158,16 @@ tensor([0, 3]) ...@@ -158,6 +158,16 @@ tensor([0, 3])
Computes graph edges to the nearest *k* points. Computes graph edges to the nearest *k* points.
**Args:**
* **x** *(Tensor)*: Node feature matrix of shape `[N, F]`.
* **r** *(float)*: The radius.
* **batch** *(LongTensor, optional)*: Batch vector of shape `[N]`, which assigns each node to a specific example. `batch` needs to be sorted. (default: `None`)
* **loop** *(bool, optional)*: If `True`, the graph will contain self-loops. (default: `False`)
* **flow** *(string, optional)*: The flow direction when using in combination with message passing (`"source_to_target"` or `"target_to_source"`). (default: `"source_to_target"`)
* **cosine** *(boolean, optional)*: If `True`, will use the Cosine distance instead of Euclidean distance to find nearest neighbors. (default: `False`)
* **num_workers** *(int)*: Number of workers to use for computation. Has no effect in case `batch` is not `None`, or the input lies on the GPU. (default: `1`)
```python ```python
import torch import torch
from torch_cluster import knn_graph from torch_cluster import knn_graph
...@@ -177,6 +187,16 @@ tensor([[1, 2, 0, 3, 0, 3, 1, 2], ...@@ -177,6 +187,16 @@ tensor([[1, 2, 0, 3, 0, 3, 1, 2],
Computes graph edges to all points within a given distance. Computes graph edges to all points within a given distance.
**Args:**
* **x** *(Tensor)*: Node feature matrix of shape `[N, F]`.
* **r** *(float)*: The radius.
* **batch** *(LongTensor, optional)*: Batch vector of shape `[N]`, which assigns each node to a specific example. `batch` needs to be sorted. (default: `None`)
* **loop** *(bool, optional)*: If `True`, the graph will contain self-loops. (default: `False`)
* **max_num_neighbors** *(int, optional)*: The maximum number of neighbors to return for each element. (default: `32`)
* **flow** *(string, optional)*: The flow direction when using in combination with message passing (`"source_to_target"` or `"target_to_source"`). (default: `"source_to_target"`)
* **num_workers** *(int)*: Number of workers to use for computation. Has no effect in case `batch` is not `None`, or the input lies on the GPU. (default: `1`)
```python ```python
import torch import torch
from torch_cluster import radius_graph from torch_cluster import radius_graph
......
...@@ -92,11 +92,11 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, ...@@ -92,11 +92,11 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
(default: :obj:`None`) (default: :obj:`None`)
loop (bool, optional): If :obj:`True`, the graph will contain loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`) self-loops. (default: :obj:`False`)
flow (string, optional): The flow direction when using in combination flow (string, optional): The flow direction when used in combination
with message passing (:obj:`"source_to_target"` or with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
cosine (boolean, optional): If :obj:`True`, will use the cosine cosine (boolean, optional): If :obj:`True`, will use the Cosine
distance instead of euclidean distance to find nearest neighbors. distance instead of Euclidean distance to find nearest neighbors.
(default: :obj:`False`) (default: :obj:`False`)
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies effect in case :obj:`batch` is not :obj:`None`, or the input lies
......
...@@ -91,8 +91,8 @@ def radius_graph(x: torch.Tensor, r: float, ...@@ -91,8 +91,8 @@ def radius_graph(x: torch.Tensor, r: float,
loop (bool, optional): If :obj:`True`, the graph will contain loop (bool, optional): If :obj:`True`, the graph will contain
self-loops. (default: :obj:`False`) self-loops. (default: :obj:`False`)
max_num_neighbors (int, optional): The maximum number of neighbors to max_num_neighbors (int, optional): The maximum number of neighbors to
return for each element in :obj:`y`. (default: :obj:`32`) return for each element. (default: :obj:`32`)
flow (string, optional): The flow direction when using in combination flow (string, optional): The flow direction when used in combination
with message passing (:obj:`"source_to_target"` or with message passing (:obj:`"source_to_target"` or
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`) :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
num_workers (int): Number of workers to use for computation. Has no num_workers (int): Number of workers to use for computation. Has no
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment