ft_attention.h 798 Bytes
Newer Older
Casper's avatar
Casper committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#pragma once
#include <torch/extension.h>


/// @brief Single-query attention entry point.
///
/// This header only declares the function. The definition lives in another
/// translation unit (presumably the matching ft_attention source file; confirm).
///
/// Top-level `const` on the by-value parameters is deliberately not spelled
/// here. It is not part of the function type, so this declaration matches a
/// definition whether or not that definition writes it.
///
/// @param q          Query tensor. NOTE(review): shape/dtype are not visible
///                   here; confirm against the implementation.
/// @param k          Key tensor (presumably for the current step; confirm).
/// @param v          Value tensor (presumably for the current step; confirm).
/// @param k_cache    Key cache, taken as a non-const tensor handle. Copies of
///                   torch::Tensor share storage, so the implementation may
///                   write into it. NOTE(review): confirm.
/// @param v_cache    Value cache; same caveat as @p k_cache.
/// @param length_per_sample_  Optional tensor. Its meaning, and the behavior
///                   when it is absent, are defined by the implementation.
/// @param alibi_slopes_  Optional tensor. The name suggests ALiBi slopes; confirm.
/// @param timestep   Integer step index. Presumably the current decoding
///                   position; confirm.
/// @param rotary_embedding_dim  Defaults to 0. NOTE(review): presumably 0
///                   disables rotary embeddings; confirm.
/// @param rotary_base  Defaults to 10000.0f. Presumably the rotary frequency base.
/// @param neox_rotary_style  Defaults to true. Presumably selects the GPT-NeoX
///                   rotary layout over the interleaved one; confirm.
/// @return Result tensor (presumably the attention output; shape not visible here).
torch::Tensor single_query_attention(
    torch::Tensor q,
    torch::Tensor k,
    torch::Tensor v,
    torch::Tensor k_cache,
    torch::Tensor v_cache,
    c10::optional<const torch::Tensor> length_per_sample_,
    c10::optional<const torch::Tensor> alibi_slopes_,
    int timestep,
    int rotary_embedding_dim = 0,
    float rotary_base = 10000.0f,
    bool neox_rotary_style = true);