"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "9a7235faf2835d424c4587e703024248e6b9f465"
Commit c5045e2f authored by rusty1s's avatar rusty1s
Browse files

fix numerical issues in nn

parent 24595a8f
...@@ -79,16 +79,18 @@ def nearest(x: torch.Tensor, y: torch.Tensor, ...@@ -79,16 +79,18 @@ def nearest(x: torch.Tensor, y: torch.Tensor,
assert y.size(0) == batch_y.size(0) assert y.size(0) == batch_y.size(0)
# Translate and rescale x and y to [0, 1]. # Translate and rescale x and y to [0, 1].
min_xy = min(x.min().item(), y.min().item()) if batch_x is not None and batch_y is not None:
x, y = x - min_xy, y - min_xy min_xy = min(x.min().item(), y.min().item())
x, y = x - min_xy, y - min_xy
max_xy = max(x.max().item(), y.max().item())
x.div_(max_xy) max_xy = max(x.max().item(), y.max().item())
y.div_(max_xy) x.div_(max_xy)
y.div_(max_xy)
# Concat batch/features to ensure no cross-links between examples.
x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], -1) # Concat batch/features to ensure no cross-links between examples.
y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], -1) D = x.size(-1)
x = torch.cat([x, 2 * D * batch_x.view(-1, 1).to(x.dtype)], -1)
y = torch.cat([y, 2 * D * batch_y.view(-1, 1).to(y.dtype)], -1)
return torch.from_numpy( return torch.from_numpy(
scipy.cluster.vq.vq(x.detach().cpu(), scipy.cluster.vq.vq(x.detach().cpu(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment