Unverified Commit 01bec4a3 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix #3623 (#3628)

parent b226fe01
......@@ -320,7 +320,10 @@ def reshape(input, shape):
def swapaxes(input, axis1, axis2):
return tf.transpose(input, perm=[axis1, axis2])
ndim = input.ndim
t = list(range(ndim))
t[axis1], t[axis2] = axis2 % ndim, axis1 % ndim
return tf.transpose(input, perm=t)
def zeros(shape, dtype, ctx):
......
......@@ -254,7 +254,7 @@ def _knn_graph_blas(x, k, dist='euclidean'):
ctx = F.context(x)
dist = pairwise_squared_distance(x)
k_indices = F.argtopk(dist, k, 2, descending=False)
k_indices = F.astype(F.argtopk(dist, k, 2, descending=False), F.int64)
# index offset for each sample
offset = F.arange(0, n_samples, ctx=ctx) * n_points
offset = F.unsqueeze(offset, 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment