// ops.h — declarations of custom operators on torch::Tensor.
#pragma once

#include <cstdint>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
10
11
12
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
13
14
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
15
16
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
17

18
19
20
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
21
22
23
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
24
25
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
26
27
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
28
29

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
30
              double epsilon);
31
32

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
33
                        torch::Tensor& weight, double epsilon);
34
35

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
36
                      torch::Tensor& key, int64_t head_size,
37
38
39
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
40
                              torch::Tensor& key, int64_t head_size,
41
                              torch::Tensor& cos_sin_cache, bool is_neox,
42
                              int64_t rot_dim,
43
44
45
46
47
48
49
50
51
52
53
                              torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
54

55
56
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

57
58
59
60
61
// Advances per-sequence decoding state by one step. All tensor arguments are
// taken by mutable reference and are presumably updated in place (for
// example from `sampled_token_ids`) — TODO confirm against the kernel.
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);

#ifndef USE_ROCM
63
64
65
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
66
                        const std::vector<int64_t>& codebook_partition_sizes,
67
68
                        const std::optional<torch::Tensor>& bias);

69
70
71
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);
72
73
74

// AWQ-quantized GEMM of `_in_feats` against the packed `_kernel`, with
// per-group `_scaling_factors` / `_zeros` (presumed from names). The meaning
// of `split_k_iters` is not visible here.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);

// Dequantizes an AWQ-packed `_kernel`. The meaning of `thx` / `thy` is not
// visible here — TODO document from the kernel.
torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);

// Marlin GEMM of `a` (size_m x size_k, presumably) against the quantized
// `b_q_weight` with `b_scales`. `workspace` is presumably scratch memory for
// the kernel.
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

// Machete GEMM entry points. `btype` describes the element type of the B
// operand as a vllm::ScalarTypeTorchPtr.
namespace machete {

// Names of the kernel schedules available for `btype`. These are presumably
// the accepted values of gemm()'s `schedule` argument — confirm.
std::vector<std::string> supported_schedules(
    vllm::ScalarTypeTorchPtr const& btype);

// A x B with optional `scales` / `zeros` / `group_size` for the B operand and
// an optional `C` combined via `alpha` / `beta` (presumed from the names).
// `schedule`, when set, presumably forces a specific kernel schedule.
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                   vllm::ScalarTypeTorchPtr const& btype,
                   c10::optional<torch::Tensor> const& scales,
                   c10::optional<torch::Tensor> const& zeros,
                   c10::optional<int64_t> group_size,
                   c10::optional<torch::Tensor> const& C,
                   c10::optional<double> alpha, c10::optional<double> beta,
                   c10::optional<std::string> schedule);

// Rearranges B into the layout expected by gemm() (presumed from the name).
torch::Tensor prepack_B(torch::Tensor const& B,
                        vllm::ScalarTypeTorchPtr const& btype);

}  // namespace machete

// Marlin GEMM variant taking an extra `b_meta` tensor (presumably sparsity
// metadata, given the "24" in the name — TODO confirm). `b_q_type` gives the
// quantized weight's scalar type.
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

// GPTQ Marlin GEMM of `a` against `b_q_weight` (scalar type `b_q_type`) with
// `b_scales`, optional zero points (`b_zeros`, gated by `has_zp`) and an
// activation reordering given by `g_idx` / `perm` (presumed from the names).
// `is_k_full` and `use_fp32_reduce` are kernel switches whose exact semantics
// are not visible here.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce);

// Repacks a GPTQ-quantized weight (optionally reordered by `perm`) into the
// Marlin layout and returns the result (presumed from the name).
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

// Same as gptq_marlin_repack but for AWQ-packed weights, which take no `perm`.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

// GGML-format quantized weight ops. `type` presumably identifies the GGML
// quantization format — TODO confirm the encoding.
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                              int64_t n);

// Matrix-vector product of quantized `W` with `X`.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);

// Matrix-matrix product of quantized `W` with `X`.
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                              int64_t row);

// Marlin GEMM with an FP8-quantized weight `b_q_weight` and `b_scales`.
// `num_bits` gives the packed weight width (presumed from the name).
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

145
146
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
147
148
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);
149

150
151
152
153
154
155
156
157
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);

158
159
160
161
162
163
164
// Marlin QQQ GEMM. `s_tok`, `s_ch` and `s_group` are presumably per-token,
// per-channel and per-group scales — TODO confirm against the kernel.
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
                              torch::Tensor const& s_ch,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k);
#endif  // USE_ROCM

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
169

170
171
172
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);

173
174
175
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
176
                        bool use_exllama, int64_t bit);
177

178
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
179

180
181
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                             torch::Tensor const& scale);
182

183
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
184
185
                              torch::Tensor& scale);

186
187
188
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
    c10::optional<torch::Tensor> const& scale_ub);
189

190
191
// Mixture-of-experts helper. It groups token ids from `topk_ids` by expert,
// aligned to `block_size` (presumed from the name). `sorted_token_ids`,
// `experts_ids` and `num_tokens_post_pad` are presumably outputs filled by
// the kernel — TODO confirm.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

// Forward selective scan. Returns a vector of tensors whose contents are not
// visible here. `D_`, `z_`, `delta_bias_`, `index_` and `x` are optional
// inputs; `delta_softplus` is a flag presumably applying softplus to `delta`.
std::vector<torch::Tensor> selective_scan_fwd(
    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
    const torch::Tensor& B, const torch::Tensor& C,
    const c10::optional<torch::Tensor>& D_,
    const c10::optional<torch::Tensor>& z_,
    const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
    const c10::optional<torch::Tensor>& index_,
    const c10::optional<torch::Tensor>& x);

// Single-step causal 1-D convolution update using `conv_state`, with an
// optional bias and an optional SiLU activation (presumed from the names).
at::Tensor causal_conv1d_update(const at::Tensor& x,
                                const at::Tensor& conv_state,
                                const at::Tensor& weight,
                                const c10::optional<at::Tensor>& bias_,
                                bool silu_activation);

// Full forward causal 1-D convolution. Optional `seq_idx_`,
// `initial_states_` and `final_states_out_` presumably carry per-sequence
// state in and out — TODO confirm against the kernel.
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                             const c10::optional<at::Tensor>& bias_,
                             const c10::optional<at::Tensor>& seq_idx_,
                             const c10::optional<at::Tensor>& initial_states_,
                             const c10::optional<at::Tensor>& final_states_out_,
                             bool silu_activation);

#ifndef USE_ROCM
218
using fptr_t = int64_t;
219
220
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
221
                      const std::vector<int64_t>& offsets, int64_t rank,
222
                      bool full_nvlink);
223
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
224
                      bool full_nvlink);
225
226
227
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
228
void dispose(fptr_t _fa);
229
int64_t meta_size();
230
231
232
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
233
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
234
235
236
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
237
#endif