Commit 1e636721 authored by zhuwenwen's avatar zhuwenwen
Browse files

add convert_vertical_slash_indexes and convert_vertical_slash_indexes_mergehead

parent d70425b2
......@@ -179,8 +179,6 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
......@@ -205,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead(
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
......
......@@ -234,7 +234,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
......@@ -257,7 +256,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead);
#endif
// Activation ops
// Activation function used in SwiGLU.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment