import torch


def fast_topk(values, topk, dim):
    if topk == 1:
        # Use max along the specified dimension to get both value and index
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        # TODO: implement faster cuda kernels for large vocab sizes
        return torch.topk(values, topk, dim=dim)
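

if __name__ == "__main__":
    # Illustrative usage sketch (added for clarity, not part of the original
    # source): both branches return a (values, indices) pair whose shape keeps
    # `dim` with size `topk`, so callers can unpack the result the same way
    # regardless of k.
    logits = torch.randn(2, 5)
    top1_vals, top1_idx = fast_topk(logits, topk=1, dim=-1)  # shapes: (2, 1)
    top3_vals, top3_idx = fast_topk(logits, topk=3, dim=-1)  # shapes: (2, 3)
    assert top1_vals.shape == (2, 1) and top3_vals.shape == (2, 3)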