Commit 1e636721 authored by zhuwenwen
Browse files

add convert_vertical_slash_indexes and convert_vertical_slash_indexes_mergehead

parent d70425b2
...@@ -179,8 +179,6 @@ void merge_attn_states(torch::Tensor& output, ...@@ -179,8 +179,6 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes( void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
...@@ -205,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead( ...@@ -205,7 +203,6 @@ void convert_vertical_slash_indexes_mergehead(
torch::Tensor vertical_indices_count, // [N_HEADS, ] torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, int64_t context_size, torch::Tensor slash_indices_count, int64_t context_size,
int64_t block_size_M, int64_t block_size_N, bool causal); int64_t block_size_M, int64_t block_size_N, bool causal);
#endif
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); double epsilon);
......
...@@ -234,7 +234,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -234,7 +234,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def( ops.def(
"convert_vertical_slash_indexes(" "convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, " " Tensor! block_count, Tensor! block_offset, "
...@@ -257,7 +256,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -257,7 +256,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" bool causal) -> ()"); " bool causal) -> ()");
ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA,
&convert_vertical_slash_indexes_mergehead); &convert_vertical_slash_indexes_mergehead);
#endif
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment