Commit 303e889a authored by Alexander Liao's avatar Alexander Liao
Browse files

fixed flake8 errors

parent 9016cdbf
from typing import Optional
import torch
import scipy.spatial
@torch.jit.script
def sample(col: torch.Tensor, count: int) -> torch.Tensor:
......@@ -87,7 +85,6 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
x = torch.cat([x, 2 * r * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
y = torch.cat([y, 2 * r * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
......@@ -104,6 +101,7 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
return torch.stack([row[mask], col[mask]], dim=0)
"""
def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32,
......@@ -144,7 +142,7 @@ def radius_graph(x: torch.Tensor, r: float,
row, col = (col, row) if flow == 'source_to_target' else (row, col)
else:
row, col = (col, row) if flow == 'target_to_source' else (row, col)
if not loop:
mask = row != col
row, col = row[mask], col[mask]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment