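// topk_plain.h (476 bytes) - last committed by Xiaowei.zhang.
// NOTE(review): this text was residue from a repository web view (navigation
// links, avatar text, a line-number gutter) and was folded into a comment so the header compiles.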
#pragma once
// SPDX-License-Identifier: MIT

#include "aiter_enum.h"
#include <torch/extension.h>

/// @brief Top-k selection entry point. This header only declares it; the
///        definition lives in another translation unit.
///
/// NOTE(review): the parameter descriptions below are inferred from names and
/// default values only. Confirm them against the implementation before
/// relying on them.
///
/// @param values    Taken by non-const reference. Presumably the input scores
///                  to select from. TODO confirm whether it is modified.
/// @param topk_ids  Taken by non-const reference. Presumably receives the
///                  indices of the selected elements. TODO confirm.
/// @param topk_out  Taken by non-const reference. Presumably receives the
///                  selected values. TODO confirm.
/// @param topk      Number of elements to select (presumably per row).
/// @param largest   Defaults to true. Presumably selects the largest values
///                  when true and the smallest when false. TODO confirm.
/// @param rowStarts Optional. Defaults to a default-constructed (undefined)
///                  tensor, so the callee can presumably test `defined()`.
///                  Looks like per-row start offsets. TODO confirm.
/// @param rowEnds   Optional. Defaults to an undefined tensor. Presumably
///                  per-row end offsets paired with rowStarts. TODO confirm.
/// @param stride0   Defaults to -1, which looks like a sentinel (e.g. "derive
///                  from the tensor"). TODO confirm.
/// @param stride1   Defaults to 1. Presumably the element stride within a row.
///                  TODO confirm.
///
/// NOTE(review): rowStarts/rowEnds are taken by value, i.e. a handle copy.
/// Switching them to `const torch::Tensor&` would require changing the
/// definition in the same commit, or this declaration would stop matching it.
/// NOTE(review): these default arguments only apply to C++ callers. If the
/// function is also exposed to Python, its defaults there are presumably set
/// by the binding, so confirm they agree with the ones here.
void topk_plain(torch::Tensor& values,
                torch::Tensor& topk_ids,
                torch::Tensor& topk_out,
                int topk,
                bool largest = true,
                torch::Tensor rowStarts = torch::Tensor(),
                torch::Tensor rowEnds = torch::Tensor(),
                int64_t stride0 = -1,
                int64_t stride1 = 1);