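// ops.h: declarations of the attention, normalization, rotary-embedding,
// activation, quantized-GEMM and related operator functions.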
#pragma once

#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

8
void paged_attention_v1(
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// Paged attention, v2 variant. Declaration only.
//
// Takes the same arguments as paged_attention_v1 plus three extra tensors
// (`exp_sums`, `max_logits`, `tmp_out`) placed right after `out`.
// NOTE(review): the extra tensors look like intermediate buffers for a
// partitioned reduction -- confirm against the kernel implementation.
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// `_opt` variant of paged_attention_v1; the parameter list is identical.
// NOTE(review): presumably an optimized kernel -- how it differs is not
// visible from this header.
void paged_attention_v1_opt(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// `_opt` variant of paged_attention_v2; the parameter list is identical.
// NOTE(review): presumably an optimized kernel -- how it differs is not
// visible from this header.
void paged_attention_v2_opt(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// `_opt_tc` variant of paged_attention_v1; the parameter list is identical.
// NOTE(review): the meaning of the `tc` suffix is not visible from this
// header -- confirm against the implementation.
void paged_attention_v1_opt_tc(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// `_opt_tc` variant of paged_attention_v2; the parameter list is identical.
// NOTE(review): the meaning of the `tc` suffix is not visible from this
// header -- confirm against the implementation.
void paged_attention_v2_opt_tc(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);


// paged_attention with attn_masks
void paged_attention_v1_with_mask(
74
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
75
76
77
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
78
79
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
80
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
81
82
83
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
84

85
void paged_attention_v2_with_mask(
86
87
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
88
89
90
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
91
92
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
93
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
94
95
96
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
97

98
void paged_attention_v1_opt_with_mask(
zhuwenwen's avatar
zhuwenwen committed
99
100
101
102
103
104
105
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
106
107
108
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
zhuwenwen's avatar
zhuwenwen committed
109

110
void paged_attention_v2_opt_with_mask(
zhuwenwen's avatar
zhuwenwen committed
111
112
113
114
115
116
117
118
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
119
120
121
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
zhuwenwen's avatar
zhuwenwen committed
122

123
void paged_attention_v1_opt_tc_with_mask(
124
125
126
127
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
zhuwenwen's avatar
zhuwenwen committed
128
129
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
130
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
131
132
133
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
134

135
void paged_attention_v2_opt_tc_with_mask(
136
137
138
139
140
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
zhuwenwen's avatar
zhuwenwen committed
141
142
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
143
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
144
145
146
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
147

148

149
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
150
              double epsilon);
151
152

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
153
                        torch::Tensor& weight, double epsilon);
154

zhuwenwen's avatar
zhuwenwen committed
155
156
157
158
159
160
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon);

// `_opt` variant of fused_add_rms_norm; the parameter list is identical.
// NOTE(review): presumably an optimized kernel -- how it differs is not
// visible from this header.
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon);

161
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
162
                      torch::Tensor& key, int64_t head_size,
163
164
165
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
166
                              torch::Tensor& key, int64_t head_size,
167
                              torch::Tensor& cos_sin_cache, bool is_neox,
168
                              int64_t rot_dim,
169
                              torch::Tensor& cos_sin_cache_offsets);
huangwb's avatar
huangwb committed
170
171
172
173
174
175
176
void rotary_embedding_tgi(
  torch::Tensor& query,
  torch::Tensor& key,
  int64_t head_size,
  torch::Tensor& cos_cache,
  torch::Tensor& sin_cache,
  bool is_neox);
177
178
179
180
181
182
183

// SiLU-and-multiply activation op. Declaration only.
// NOTE(review): presumably reads `input` and writes the gated result into
// `out` -- confirm against the kernel implementation.
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

// GELU-and-multiply activation op. Declaration only.
// NOTE(review): presumably reads `input` and writes the gated result into
// `out` -- confirm against the kernel implementation.
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

// GELU(tanh)-and-multiply activation op. Declaration only.
// NOTE(review): presumably the tanh approximation of GELU followed by a
// multiply, written into `out` -- confirm against the kernel implementation.
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
184
185
186
187
188
189
void silu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

// `_opt` variant of gelu_and_mul; the parameter list is identical.
// NOTE(review): presumably an optimized kernel -- how it differs is not
// visible from this header.
void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

// `_opt` variant of gelu_tanh_and_mul; the parameter list is identical.
// NOTE(review): presumably an optimized kernel -- how it differs is not
// visible from this header.
void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

190
191
192
void gelu_new(torch::Tensor& out, torch::Tensor& input);

// "gelu_fast" activation op. Declaration only.
// NOTE(review): presumably applies the activation to `input` and writes into
// `out`; the exact GELU formulation is not visible here -- confirm.
void gelu_fast(torch::Tensor& out, torch::Tensor& input);

194
195
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
196
197
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);

198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables);

// Advance-step op, flashinfer flavour. Declaration only.
// Takes the advance_step_flashattn arguments plus four paged-KV tensors
// (`paged_kv_indices`, `paged_kv_indptr`, `paged_kv_last_page_len`,
// `block_table_bounds`).
// NOTE(review): which of these tensors are updated in place is not visible
// from this header -- confirm against the implementation.
void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

// The guarded section below is compiled only when USE_ROCM is not defined.
#ifndef USE_ROCM
215
216
217
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
218
                        const std::vector<int64_t>& codebook_partition_sizes,
219
220
                        const std::optional<torch::Tensor>& bias);

221
222
223
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);
224
225
226

// AWQ GEMM. Declaration only; returns a new tensor.
// All tensors are taken by value.
// NOTE(review): `_kernel` is presumably the packed quantized weight, with
// `_scaling_factors`/`_zeros` as its quantization parameters -- confirm.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);

// AWQ dequantization. Declaration only; returns a new tensor.
// NOTE(review): `thx`/`thy` look like launch-configuration hints; their
// meaning is not visible from this header -- confirm.
torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);

// Marlin GEMM. Declaration only; returns a new tensor.
// NOTE(review): `size_m`/`size_n`/`size_k` are presumably the GEMM problem
// dimensions and `workspace` a scratch buffer -- confirm.
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
namespace machete {

std::vector<std::string> supported_schedules(
    vllm::ScalarTypeTorchPtr const& btype);

torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                   vllm::ScalarTypeTorchPtr const& btype,
                   c10::optional<torch::Tensor> const& scales,
                   c10::optional<torch::Tensor> const& zeros,
                   c10::optional<int64_t> group_size,
                   c10::optional<torch::Tensor> const& C,
                   c10::optional<double> alpha, c10::optional<double> beta,
                   c10::optional<std::string> schedule);

torch::Tensor prepack_B(torch::Tensor const& B,
                        vllm::ScalarTypeTorchPtr const& btype);

};  // namespace machete

257
258
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);

259
260
261
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
262
263
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
264
265
266
267
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

// GPTQ Marlin GEMM. Declaration only; returns a new tensor.
// `b_q_type` describes the scalar type of the quantized weight; the three
// trailing booleans select kernel behavior.
// NOTE(review): the exact meaning of `is_k_full`, `has_zp` and
// `use_fp32_reduce` is not visible from this header -- confirm.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce);

// GPTQ-to-Marlin weight repack. Declaration only; returns a new tensor.
// NOTE(review): presumably `b_q_weight` rearranged into the Marlin layout
// using `perm` -- confirm against the implementation.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

280
281
282
283
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits);

284
285
286
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

287
288
289
290
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits);

291
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
292
293
                              int64_t n);

294
295
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);
296

297
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
298
299
                              int64_t row);

300
301
302
303
304
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

// Capability query. Declaration only; returns a bool for the given device
// capability value.
// NOTE(review): presumably true when cutlass_scaled_mm can run FP8 inputs on
// that capability; the encoding of the argument is not visible here -- confirm.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

307
308
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
309
310
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);
311

312
313
314
315
316
317
318
319
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);

320
321
322
323
324
325
326
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
                              torch::Tensor const& s_ch,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k);
327
#endif
328

329
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
330
331
                              torch::Tensor const& scale,
                              c10::optional<torch::Tensor> const& azp);
332

333
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
334
335
                               torch::Tensor& scales,
                               c10::optional<torch::Tensor> const& azp);
336

// torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
//                         torch::Tensor b_gptq_qzeros,
//                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
//                         bool use_exllama, int64_t bit);

// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                              torch::Tensor const& scale);

// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                               torch::Tensor& scale);

// void dynamic_per_token_scaled_fp8_quant(
//     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//     c10::optional<torch::Tensor> const& scale_ub);

354
355
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
356
357
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
358

359
360
361
362
363
364
365
366
367
std::vector<torch::Tensor> selective_scan_fwd(
    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
    const torch::Tensor& B, const torch::Tensor& C,
    const c10::optional<torch::Tensor>& D_,
    const c10::optional<torch::Tensor>& z_,
    const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
    const c10::optional<torch::Tensor>& index_,
    const c10::optional<torch::Tensor>& x);

368
369
370
371
at::Tensor causal_conv1d_update(
    const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
    const c10::optional<at::Tensor>& bias, bool silu_activation,
    const c10::optional<at::Tensor>& conv_state_indices);
372
373
374
375
376
377
378
379

// Causal conv1d forward pass. Declaration only; returns a tensor.
// `bias_`, `seq_idx_`, `initial_states_` and `final_states_out_` are optional;
// `silu_activation` toggles the activation.
// NOTE(review): `final_states_out_` is presumably an output buffer despite
// being a const reference -- confirm against the implementation.
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                             const c10::optional<at::Tensor>& bias_,
                             const c10::optional<at::Tensor>& seq_idx_,
                             const c10::optional<at::Tensor>& initial_states_,
                             const c10::optional<at::Tensor>& final_states_out_,
                             bool silu_activation);

// The guarded section below is compiled only when USE_ROCM is not defined.
#ifndef USE_ROCM
// Integer handle type: returned by init_custom_ar and passed as `_fa` to the
// other custom all-reduce functions.
// NOTE(review): the name suggests it carries a pointer value -- confirm.
using fptr_t = int64_t;
382
383
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
384
                      const std::vector<int64_t>& offsets, int64_t rank,
385
386
387
388
                      bool full_nvlink);
// All-reduce on the context identified by `_fa`. Declaration only.
// NOTE(review): the `_reg` suffix suggests `inp` must already be a registered
// buffer, with the result written to `out` -- confirm.
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
// All-reduce on the context identified by `_fa`, taking an extra
// `reg_buffer` tensor. Declaration only.
// NOTE(review): presumably stages an unregistered `inp` through `reg_buffer`
// before reducing into `out` -- confirm against the implementation.
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
// Declaration only; takes the handle `_fa`.
// NOTE(review): presumably releases the context created by init_custom_ar --
// confirm against the implementation.
void dispose(fptr_t _fa);
// Declaration only; returns an int64_t and takes no arguments.
// NOTE(review): presumably the byte size required for the `meta` tensor given
// to init_custom_ar -- confirm against the implementation.
int64_t meta_size();
// Registers tensor `t` with the context identified by `_fa`, together with
// the given `handles` and `offsets`. Declaration only.
// NOTE(review): what the handles and offsets describe is not visible from
// this header -- confirm against the implementation.
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
395
396
397
    fptr_t _fa);
// Registers graph buffers with the context identified by `_fa`, given a list
// of handles and one offsets list per entry. Declaration only.
// NOTE(review): presumably consumes data gathered from
// get_graph_buffer_ipc_meta on every rank -- confirm against the
// implementation.
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif