Unverified Commit 6beef85b authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

upd (#1113)

parent 41114b87
......@@ -110,6 +110,14 @@ def argsort(input, dim, descending):
return np.argsort(-input, axis=dim)
return np.argsort(input, axis=dim)
def topk(input, k, dim, descending=True):
topk_indices = argtopk(input, k, dim, descending)
return np.take_along_axis(input, topk_indices, axis=dim)
def argtopk(input, k, dim, descending=True):
sort_indces = argsort(input, dim, descending)
return slice_axis(sort_indces, dim, 0, k)
def exp(input):
return np.exp(input)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment