kernels.cpp 577 Bytes
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include "torch/extension.h"
#include <vector>

/// Forward declaration of the routine exported to Python below. Its
/// definition is not visible in this part of the file (presumably it is
/// provided by a separately compiled source — confirm against the build).
///
/// @param seqlens           tensor annotated as [BATCH, ]
/// @param vertical_indexes  tensor annotated as [BATCH, N_HEADS, NNZ_V]
/// @param slash_indexes     tensor annotated as [BATCH, N_HEADS, NNZ_S]
/// @param context_size      NOTE(review): meaning not shown here — see the
///                          definition
/// @param block_size_M      NOTE(review): presumably a tile size along one
///                          axis — confirm against the definition
/// @param block_size_N      NOTE(review): presumably a tile size along the
///                          other axis — confirm against the definition
/// @return a vector of tensors; how many and what they hold is determined by
///         the definition, which is not visible here.
std::vector<at::Tensor> convert_vertical_slash_indexes(
    torch::Tensor seqlens,
    torch::Tensor vertical_indexes,
    torch::Tensor slash_indexes,
    int context_size,
    int block_size_M,
    int block_size_N);

// Python binding entry point. The module name comes from the
// TORCH_EXTENSION_NAME macro, which is not defined in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  // Register the function under the same name it has in C++, together with
  // its one-line Python docstring.
  ext.def(
      "convert_vertical_slash_indexes",
      &convert_vertical_slash_indexes,
      "dynamic sparse index function");
}