#pragma once

#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
10
11
12
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
13
14
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
15
16
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
21
22
23
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
24
25
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
26
27
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
30
              double epsilon);
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
33
                        torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
36
                      torch::Tensor& key, int64_t head_size,
37
38
39
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
40
                              torch::Tensor& key, int64_t head_size,
41
                              torch::Tensor& cos_sin_cache, bool is_neox,
42
                              int64_t rot_dim,
43
44
45
46
47
48
49
50
51
52
53
                              torch::Tensor& cos_sin_cache_offsets);

// Declaration of the `silu_and_mul` activation op; defined elsewhere.
// NOTE(review): presumably reads `input` and writes `out` -- confirm.
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_and_mul` activation op; defined elsewhere.
// NOTE(review): presumably reads `input` and writes `out` -- confirm.
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_tanh_and_mul` activation op; defined elsewhere.
// NOTE(review): presumably reads `input` and writes `out` -- confirm.
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_new` activation op; defined elsewhere.
// NOTE(review): presumably reads `input` and writes `out` -- confirm.
void gelu_new(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `gelu_fast` activation op; defined elsewhere.
// NOTE(review): presumably reads `input` and writes `out` -- confirm.
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
// Declaration of the `gelu_quick` activation op; defined elsewhere.
// NOTE(review): presumably reads `input` and writes `out` -- confirm.
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

// Declaration of the `advance_step_flashattn` op; defined elsewhere.
// Takes three integer sizes followed by six tensors, all by non-const
// reference.
// NOTE(review): presumably updates the per-step input tensors in place from
// `sampled_token_ids` -- confirm against the implementation.
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables);

// Declaration of the `advance_step_flashinfer` op; defined elsewhere.
// Same leading arguments as `advance_step_flashattn`, followed by four
// additional tensors (`paged_kv_indices`, `paged_kv_indptr`,
// `paged_kv_last_page_len`, `block_table_bounds`).
// NOTE(review): presumably the extra tensors are also updated in place --
// confirm against the implementation.
void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
#ifndef USE_ROCM
// Declaration of the `aqlm_gemm` op; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// All inputs are const references; `bias` is optional. Returns a tensor by
// value.
// NOTE(review): presumably a matmul of `input` against AQLM-encoded weights
// described by `codes`/`codebooks`/`scales` -- confirm.
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const std::vector<int64_t>& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

// Declaration of the `aqlm_dequant` op; defined elsewhere. Only declared
// when USE_ROCM is not defined.
// All inputs are const references; returns a tensor by value.
// NOTE(review): presumably returns the dequantized weights -- confirm.
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);
// Declaration of the `awq_gemm` op; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// Tensor arguments are passed by value; returns a tensor by value.
// NOTE(review): presumably a matmul of `_in_feats` against AWQ-quantized
// `_kernel` using `_scaling_factors` and `_zeros` -- confirm.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);
// Declaration of the `awq_dequantize` op; defined elsewhere. Only declared
// when USE_ROCM is not defined.
// Tensor arguments are passed by value; returns a tensor by value.
// NOTE(review): `thx`/`thy` look like launch-shape hints -- confirm their
// meaning against the implementation.
torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
// Declaration of the `permute_cols` op; defined elsewhere. Only declared
// when USE_ROCM is not defined.
// NOTE(review): presumably returns `A` with its columns reordered according
// to `perm` -- confirm.
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);

// Declaration of the `ggml_dequantize` op; defined elsewhere. Only declared
// when USE_ROCM is not defined.
// `W` is passed by value; `type`, `m` and `n` are integers.
// NOTE(review): `type` presumably identifies the GGML quantization format
// and `m`/`n` the output dimensions -- confirm.
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                              int64_t n);

// Declaration of the `ggml_mul_mat_vec_a8` op; defined elsewhere. Only
// declared when USE_ROCM is not defined.
// NOTE(review): presumably multiplies quantized `W` by `X`; meaning of
// `type` and `row` must be checked against the implementation.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);
// Declaration of the `ggml_mul_mat_a8` op; defined elsewhere. Only declared
// when USE_ROCM is not defined.
// Same parameter list as `ggml_mul_mat_vec_a8`.
// NOTE(review): presumably the matrix-matrix counterpart of that op --
// confirm against the implementation.
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                              int64_t row);

// Declaration of a capability query; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// NOTE(review): presumably returns true when `cutlass_scaled_mm` can handle
// FP8 on a device of the given capability; encoding of the capability
// integer is not visible here -- confirm.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

// Declaration of the `cutlass_scaled_mm` op; defined elsewhere. Only
// declared when USE_ROCM is not defined.
// `out` is the only non-const tensor; `a`, `b` and both scale tensors are
// const references and `bias` is optional.
// NOTE(review): presumably writes the scaled product of `a` and `b` (plus
// `bias` when present) into `out` -- confirm.
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);
// Declaration of the `cutlass_scaled_mm_azp` op; defined elsewhere. Only
// declared when USE_ROCM is not defined.
// Same arguments as `cutlass_scaled_mm` with a required `azp_adj` tensor and
// an optional `azp` tensor inserted before `bias`.
// NOTE(review): "azp" presumably stands for an asymmetric zero point --
// confirm against the implementation.
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);
#endif
// Declaration of the `static_scaled_int8_quant` op; defined elsewhere.
// `out` is a non-const reference; `input` and `scale` are const references
// and `azp` is optional.
// NOTE(review): presumably quantizes `input` to int8 into `out` using the
// caller-provided `scale` -- confirm.
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale,
                              c10::optional<torch::Tensor> const& azp);
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
125
126
                               torch::Tensor& scales,
                               c10::optional<torch::Tensor> const& azp);
// Declaration of the `gptq_gemm` op; defined elsewhere.
// Tensor arguments are passed by value; returns a tensor by value.
// NOTE(review): presumably a matmul of `a` against GPTQ-quantized weights;
// `use_exllama` and `bit` likely select the kernel path and bit width --
// confirm against the implementation.
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);
// Declaration of the `gptq_shuffle` op; defined elsewhere.
// NOTE(review): returns nothing, so presumably reorders `q_weight` in place
// according to `q_perm` -- confirm.
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
// Declaration of the `static_scaled_fp8_quant` op; defined elsewhere.
// `out` is a non-const reference; `input` and `scale` are const references.
// NOTE(review): presumably quantizes `input` to FP8 into `out` using the
// caller-provided `scale` -- confirm.
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                             torch::Tensor const& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
139
140
                              torch::Tensor& scale);

// Declaration of the `dynamic_per_token_scaled_fp8_quant` op; defined
// elsewhere.
// `out` and `scale` are non-const references, `input` is const, and
// `scale_ub` is optional.
// NOTE(review): `scale_ub` presumably caps the computed per-token scale --
// confirm against the implementation.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
    c10::optional<torch::Tensor> const& scale_ub);
// Declaration of the `moe_align_block_size` op; defined elsewhere.
// Tensor arguments are passed by value and the function returns nothing.
// NOTE(review): presumably fills `sorted_token_ids`, `experts_ids` and
// `num_tokens_post_pad` from `topk_ids` -- confirm against the
// implementation.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
// Declaration of the `selective_scan_fwd` op; defined elsewhere.
// Every tensor argument is a const reference; `D_`, `z_`, `delta_bias_`,
// `query_start_loc`, `cache_indices` and `has_initial_state` are optional.
// NOTE(review): the function returns void and has no non-const tensor
// parameter, so results are presumably written through one of the tensor
// handles (e.g. `ssm_states`) -- confirm against the implementation.
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                        const torch::Tensor& A, const torch::Tensor& B,
                        const torch::Tensor& C,
                        const c10::optional<torch::Tensor>& D_,
                        const c10::optional<torch::Tensor>& z_,
                        const c10::optional<torch::Tensor>& delta_bias_,
                        bool delta_softplus,
                        const c10::optional<torch::Tensor>& query_start_loc,
                        const c10::optional<torch::Tensor>& cache_indices,
                        const c10::optional<torch::Tensor>& has_initial_state,
                        const torch::Tensor& ssm_states);
// Declaration of the `causal_conv1d_update` op; defined elsewhere.
// All tensor arguments are const references; `bias_`, `cache_seqlens_` and
// `conv_state_indices_` are optional. Returns a tensor by value.
// NOTE(review): presumably a single-step causal conv that also refreshes
// `conv_state` -- confirm against the implementation.
at::Tensor causal_conv1d_update(
    const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias_, bool silu_activation,
    const c10::optional<at::Tensor>& cache_seqlens_,
    const c10::optional<at::Tensor>& conv_state_indices_);
// Declaration of the `causal_conv1d_fwd` op; defined elsewhere.
// `x` and `weight` are required const references; `bias_`, `conv_states`,
// `query_start_loc`, `cache_indices` and `has_initial_state` are optional.
// Returns a tensor by value.
// NOTE(review): `silu_activation` presumably toggles a fused SiLU on the
// output -- confirm against the implementation.
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                             const c10::optional<at::Tensor>& bias_,
                             const c10::optional<at::Tensor>& conv_states,
                             const c10::optional<at::Tensor>& query_start_loc,
                             const c10::optional<at::Tensor>& cache_indices,
                             const c10::optional<at::Tensor>& has_initial_state,
                             bool silu_activation);

#ifndef USE_ROCM
// Integer handle type used by the custom all-reduce functions declared
// below (returned by `init_custom_ar`, accepted as `_fa` by the others).
// NOTE(review): presumably carries a pointer value cast to an integer --
// confirm against the implementation.
using fptr_t = int64_t;
// Declaration of `init_custom_ar`; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// Returns an `fptr_t` handle that the other custom all-reduce functions
// take as their `_fa` argument.
// NOTE(review): `handles`/`offsets` presumably describe per-rank IPC
// buffers, and the handle is presumably released with `dispose` -- confirm.
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink);
// Declaration of `all_reduce_reg`; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// NOTE(review): presumably all-reduces `inp` into `out`, assuming `inp` is
// already a registered buffer -- confirm.
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
// Declaration of `all_reduce_unreg`; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// NOTE(review): presumably stages an unregistered `inp` through
// `reg_buffer` before reducing into `out` -- confirm.
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
// Declaration of `dispose`; defined elsewhere. Only declared when USE_ROCM
// is not defined.
// NOTE(review): presumably releases the object behind the `_fa` handle --
// confirm.
void dispose(fptr_t _fa);
// Declaration of `meta_size`; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// NOTE(review): presumably the byte size required for the `meta` tensor
// passed to `init_custom_ar` -- confirm.
int64_t meta_size();
// Declaration of `register_buffer`; defined elsewhere. Only declared when
// USE_ROCM is not defined.
// NOTE(review): presumably registers tensor `t` with the all-reduce context
// using per-rank `handles` and `offsets` -- confirm.
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
191
192
193
    fptr_t _fa);
// Declaration of `register_graph_buffers`; defined elsewhere. Only declared
// when USE_ROCM is not defined.
// NOTE(review): presumably takes one handle string and one offset list per
// rank -- confirm against the implementation.
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif