Commit 303e889a authored by Alexander Liao's avatar Alexander Liao
Browse files

fixed flake8 errors

parent 9016cdbf
from typing import Optional from typing import Optional
import torch import torch
import scipy.spatial
@torch.jit.script @torch.jit.script
def sample(col: torch.Tensor, count: int) -> torch.Tensor: def sample(col: torch.Tensor, count: int) -> torch.Tensor:
...@@ -87,7 +85,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -87,7 +85,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
assert x.size(0) == batch_x.size(0) assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0) assert y.size(0) == batch_y.size(0)
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1) x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1) y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
...@@ -104,6 +101,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, ...@@ -104,6 +101,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
return torch.stack([row[mask], col[mask]], dim=0) return torch.stack([row[mask], col[mask]], dim=0)
""" """
def radius_graph(x: torch.Tensor, r: float, def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False, batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32, max_num_neighbors: int = 32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment