"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "ad228bb87c73abea77ca753c8efa347142b8dd2d"
Commit 21208fce authored by rusty1s

nearest cpu

parent 07c92be4
@@ -4,12 +4,9 @@ import pytest
 import torch
 from torch_cluster import nearest
 
-from .utils import tensor, grad_dtypes
-
-devices = [torch.device('cuda')]
+from .utils import grad_dtypes, devices, tensor
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_nearest(dtype, device):
     x = tensor([
...
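With the CUDA-only skip removed and devices now imported from the shared test utilities, the parametrized test above exercises nearest on CPU as well as CUDA. A minimal sketch of such a CPU call follows; the coordinates and the expected assignment are illustrative and not taken from the elided test body:

import torch
from torch_cluster import nearest

# Two examples, each with two query points (x) and two candidate points (y).
x = torch.tensor([[-1.0, -1.0], [1.0, 1.0], [-2.0, -2.0], [2.0, 2.0]])
batch_x = torch.tensor([0, 0, 1, 1])
y = torch.tensor([[-1.0, 0.0], [1.0, 0.0], [-2.0, 0.0], [2.0, 0.0]])
batch_y = torch.tensor([0, 0, 1, 1])

# CPU tensors take the new scipy-based branch; each x[i] is assigned the
# (global) index of its nearest y within the same example.
cluster = nearest(x, y, batch_x, batch_y)
print(cluster)  # tensor([0, 1, 2, 3])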
 import torch
+import scipy.cluster
 
 if torch.cuda.is_available():
     import nearest_cuda
@@ -35,14 +36,24 @@ def nearest(x, y, batch_x=None, batch_y=None):
     x = x.view(-1, 1) if x.dim() == 1 else x
     y = y.view(-1, 1) if y.dim() == 1 else y
 
-    assert x.is_cuda
-
     assert x.dim() == 2 and batch_x.dim() == 1
     assert y.dim() == 2 and batch_y.dim() == 1
     assert x.size(1) == y.size(1)
     assert x.size(0) == batch_x.size(0)
     assert y.size(0) == batch_y.size(0)
 
-    op = nearest_cuda.nearest if x.is_cuda else None
-    out = op(x, y, batch_x, batch_y)
-    return out
+    if x.is_cuda:
+        return nearest_cuda.nearest(x, y, batch_x, batch_y)
+
+    # Rescale x and y.
+    min_xy = min(x.min().item(), y.min().item())
+    x, y = x - min_xy, y - min_xy
+
+    max_xy = max(x.max().item(), y.max().item())
+    x, y = x / max_xy, y / max_xy
+
+    # Concat batch/features to ensure no cross-links between examples exist.
+    x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)
+    y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)
+
+    return torch.from_numpy(scipy.cluster.vq.vq(x, y)[0]).to(torch.long)
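The new CPU branch relies on scipy.cluster.vq.vq, which has no notion of batches. After rescaling all coordinates into [0, 1], appending the batch index times 2 * D (with D the feature dimension) as an extra coordinate makes every cross-example distance at least 2 * D, larger than the sqrt(D) diameter of the unit cube, so vq can never assign a query to a codebook point from another example. A small self-contained check of that argument, with illustrative data:

import numpy as np
import scipy.cluster.vq

rng = np.random.default_rng(0)
D = 3

# Queries and codebook points for two examples, already rescaled to [0, 1].
x = rng.random((8, D))
y = rng.random((6, D))
batch_x = np.repeat([0, 1], 4)
batch_y = np.repeat([0, 1], 3)

# Append batch * 2 * D as an extra coordinate: points from different examples
# end up at least 2 * D apart, more than sqrt(D) within a single example.
x_off = np.hstack([x, 2.0 * D * batch_x[:, None]])
y_off = np.hstack([y, 2.0 * D * batch_y[:, None]])

assign = scipy.cluster.vq.vq(x_off, y_off)[0]
# Example-0 queries only map into y[0:3], example-1 queries into y[3:6].
assert (assign[batch_x == 0] < 3).all() and (assign[batch_x == 1] >= 3).all()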