Unverified Commit 01bec4a3 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix #3623 (#3628)

parent b226fe01
...@@ -320,7 +320,10 @@ def reshape(input, shape): ...@@ -320,7 +320,10 @@ def reshape(input, shape):
def swapaxes(input, axis1, axis2): def swapaxes(input, axis1, axis2):
return tf.transpose(input, perm=[axis1, axis2]) ndim = input.ndim
t = list(range(ndim))
t[axis1], t[axis2] = axis2 % ndim, axis1 % ndim
return tf.transpose(input, perm=t)
def zeros(shape, dtype, ctx): def zeros(shape, dtype, ctx):
......
...@@ -254,7 +254,7 @@ def _knn_graph_blas(x, k, dist='euclidean'): ...@@ -254,7 +254,7 @@ def _knn_graph_blas(x, k, dist='euclidean'):
ctx = F.context(x) ctx = F.context(x)
dist = pairwise_squared_distance(x) dist = pairwise_squared_distance(x)
k_indices = F.argtopk(dist, k, 2, descending=False) k_indices = F.astype(F.argtopk(dist, k, 2, descending=False), F.int64)
# index offset for each sample # index offset for each sample
offset = F.arange(0, n_samples, ctx=ctx) * n_points offset = F.arange(0, n_samples, ctx=ctx) * n_points
offset = F.unsqueeze(offset, 1) offset = F.unsqueeze(offset, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment