ops.h 18.9 KB
Newer Older
1
2
#pragma once

#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

// Returns a new tensor that views the same memory as `tensor`.
//
// The result is built with torch::from_blob from `tensor`'s data pointer,
// sizes, strides and options, and no deleter is supplied. from_blob does not
// take ownership of the memory, so the caller must keep `tensor`'s storage
// alive for as long as the returned tensor is used.
//
// Throws std::runtime_error if `tensor` is not on a CUDA device.
//
// Marked `inline` because the body is defined in a header: without it, every
// translation unit that includes this file emits its own external definition,
// which violates the one-definition rule and fails at link time.
inline torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // Same pointer, shape, strides and dtype/device options as the source.
  return torch::from_blob(tensor.data_ptr(), tensor.sizes(), tensor.strides(),
                          tensor.options());
}

32
void paged_attention_v1(
33
34
35
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
36
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
37
38
39
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
40
41
42
43
44
45
46
47
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
48
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
49
50
51
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// "opt" variant of paged_attention_v1 (declaration only).
//
// Same parameter order as paged_attention_v1 in this header, except that
// `k_scale`/`v_scale` are plain `double` values instead of tensors and
// `alibi_slopes` is spelled `c10::optional` instead of `std::optional`.
// NOTE(review): confirm whether the c10/std spelling difference is intended.
void paged_attention_v1_opt(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// "opt" variant of paged_attention_v2 (declaration only).
//
// Same parameter order as paged_attention_v2 in this header, except that
// `k_scale`/`v_scale` are plain `double` values instead of tensors and
// `alibi_slopes` is spelled `c10::optional` instead of `std::optional`.
// NOTE(review): confirm whether the c10/std spelling difference is intended.
void paged_attention_v2_opt(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// "opt_tc" variant of paged_attention_v1 (declaration only).
//
// The signature is identical to paged_attention_v1_opt: scalar `double`
// `k_scale`/`v_scale` and a `c10::optional` `alibi_slopes`.
// NOTE(review): what "tc" stands for is not visible from this header —
// confirm against the implementation.
void paged_attention_v1_opt_tc(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

// "opt_tc" variant of paged_attention_v2 (declaration only).
//
// The signature is identical to paged_attention_v2_opt: scalar `double`
// `k_scale`/`v_scale` and a `c10::optional` `alibi_slopes`.
// NOTE(review): what "tc" stands for is not visible from this header —
// confirm against the implementation.
void paged_attention_v2_opt_tc(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);


// paged_attention with attn_masks
void paged_attention_v1_with_mask(
100
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
101
102
103
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
104
105
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
106
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
107
108
109
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
110

111
void paged_attention_v2_with_mask(
112
113
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
114
115
116
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
117
118
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
119
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
120
121
122
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
123

124
void paged_attention_v1_opt_with_mask(
zhuwenwen's avatar
zhuwenwen committed
125
126
127
128
129
130
131
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
132
133
134
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
zhuwenwen's avatar
zhuwenwen committed
135

136
void paged_attention_v2_opt_with_mask(
zhuwenwen's avatar
zhuwenwen committed
137
138
139
140
141
142
143
144
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
145
146
147
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
zhuwenwen's avatar
zhuwenwen committed
148

149
void paged_attention_v1_opt_tc_with_mask(
150
151
152
153
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
zhuwenwen's avatar
zhuwenwen committed
154
155
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
156
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
157
158
159
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
160

161
void paged_attention_v2_opt_tc_with_mask(
162
163
164
165
166
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
zhuwenwen's avatar
zhuwenwen committed
167
168
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
169
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
170
171
172
    const int64_t blocksparse_head_sliding_step,
    const c10::optional<torch::Tensor>& attn_masks,
    const int64_t attn_masks_stride=0);
173

174

175
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
176
              double epsilon);
177
178

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
179
                        torch::Tensor& weight, double epsilon);
180

zhuwenwen's avatar
zhuwenwen committed
181
182
183
184
185
186
// "opt" counterparts of rms_norm / fused_add_rms_norm (declarations only).
// Their parameter lists match the non-opt declarations in this header.
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon);

void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon);

zhuwenwen's avatar
zhuwenwen committed
187
188
189
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
//                                torch::Tensor& weight, torch::Tensor& scale,
//                                double epsilon);

// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
//                                          torch::Tensor& input,
//                                          torch::Tensor& residual,
//                                          torch::Tensor& weight,
//                                          torch::Tensor& scale, double epsilon);
196

197
198
199
200
201
202
203
204
// RMS-norm with dynamic per-token quantization (declaration only).
//
// `input` and `weight` are read-only (const references); `out` and `scales`
// are non-const references. `scale_ub` and `residual` are optional tensors
// passed by value.
// NOTE(review): `scales` is presumably written by the kernel, and `residual`
// presumably enables a fused residual add — confirm against the
// implementation.
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                      torch::Tensor const& input,
                                      torch::Tensor const& weight,
                                      torch::Tensor& scales,
                                      double const epsilon,
                                      std::optional<torch::Tensor> scale_ub,
                                      std::optional<torch::Tensor> residual);

205
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
206
                      torch::Tensor& key, int64_t head_size,
207
208
209
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
210
                              torch::Tensor& key, int64_t head_size,
211
                              torch::Tensor& cos_sin_cache, bool is_neox,
212
                              int64_t rot_dim,
213
                              torch::Tensor& cos_sin_cache_offsets);
huangwb's avatar
huangwb committed
214
215
216
217
218
219
220
// Rotary-embedding variant (declaration only).
//
// Unlike rotary_embedding in this header, it takes no `positions` tensor and
// receives the cosine and sine caches as two separate tensors (`cos_cache`,
// `sin_cache`) rather than a single `cos_sin_cache`.
void rotary_embedding_tgi(
  torch::Tensor& query,
  torch::Tensor& key,
  int64_t head_size,
  torch::Tensor& cos_cache,
  torch::Tensor& sin_cache,
  bool is_neox);
221
222
223

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

224
225
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

226
227
228
229
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
230
231
232
233
234
235
void silu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

236
237
238
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold);

239
240
241
void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
242

243
244
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
245
246
// Declaration only. Both tensors are taken by value, unlike most declarations
// in this header, which take `torch::Tensor&`.
// NOTE(review): presumably transforms `src` (row x col) into `dst` — confirm
// against the implementation.
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// Declaration only. Takes three scalar sizes followed by six tensors, all by
// non-const reference.
// NOTE(review): which tensors are inputs and which are updated in place is
// not visible from this header — confirm against the implementation.
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables);

// Declaration only. Same leading parameters as advance_step_flashattn, plus
// four further tensors (`paged_kv_indices`, `paged_kv_indptr`,
// `paged_kv_last_page_len`, `block_table_bounds`).
// NOTE(review): which tensors are inputs and which are updated in place is
// not visible from this header — confirm against the implementation.
void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
262

263
#ifndef USE_ROCM
264
265
266
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
267
                        const std::vector<int64_t>& codebook_partition_sizes,
268
269
                        const std::optional<torch::Tensor>& bias);

270
271
272
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);
273
274
275

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
276
                       int64_t split_k_iters);
277
278
279

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
280
281
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
282

283
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
284
#endif
285

286
// GGML ops (declarations only; definitions are not in this header). Each
// takes its tensors by value and returns a new torch::Tensor.
// NOTE(review): `type` is presumably a GGML quantization type id — confirm
// against the implementation.
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                              int64_t n);

torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);

torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                              int64_t row);

295
#ifndef USE_ROCM
296
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
297
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
298

299
300
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
301
                       torch::Tensor const& b_scales,
302
                       std::optional<torch::Tensor> const& bias);
303

304
305
306
307
308
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
309
310
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias);
311

312
313
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

314
315
316
317
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                              torch::Tensor const& b, torch::Tensor const& e,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
318
                              std::optional<torch::Tensor> const& bias);
319
320
321

bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
                                   torch::Tensor& e, torch::Tensor const& a);
322
#endif
323

324
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
325
                              torch::Tensor const& scale,
326
                              std::optional<torch::Tensor> const& azp);
327

328
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
329
                               torch::Tensor& scales,
330
                               std::optional<torch::Tensor> const& azp);
331

332
333
334
335
// torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
//                         torch::Tensor b_gptq_qzeros,
//                         torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
//                         bool use_exllama, int64_t bit);

// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                              torch::Tensor const& scale);

// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                               torch::Tensor& scale);

// void dynamic_per_token_scaled_fp8_quant(
//     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//     std::optional<torch::Tensor> const& scale_ub);
348

349
350
351
// Selective-scan and causal-conv1d ops (declarations only; definitions are
// not in this header). All tensor parameters are const references, and each
// function takes a `pad_slot_id`.
// NOTE(review): although every tensor is passed as a const reference, the
// implementations may still write through them (e.g. `ssm_states`,
// `conv_state`) — confirm against the kernels.
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                        const torch::Tensor& A, const torch::Tensor& B,
                        const torch::Tensor& C,
                        const std::optional<torch::Tensor>& D_,
                        const std::optional<torch::Tensor>& z_,
                        const std::optional<torch::Tensor>& delta_bias_,
                        bool delta_softplus,
                        const std::optional<torch::Tensor>& query_start_loc,
                        const std::optional<torch::Tensor>& cache_indices,
                        const std::optional<torch::Tensor>& has_initial_state,
                        const torch::Tensor& ssm_states, int64_t pad_slot_id);

void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
                          const at::Tensor& weight,
                          const std::optional<at::Tensor>& bias_,
                          bool silu_activation,
                          const std::optional<at::Tensor>& cache_seqlens_,
                          const std::optional<at::Tensor>& conv_state_indices_,
                          int64_t pad_slot_id);

void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                       const std::optional<at::Tensor>& bias_,
                       const std::optional<at::Tensor>& conv_states,
                       const std::optional<at::Tensor>& query_start_loc,
                       const std::optional<at::Tensor>& cache_indices,
                       const std::optional<at::Tensor>& has_initial_state,
                       bool silu_activation, int64_t pad_slot_id);
376

377
#ifndef USE_ROCM
378
using fptr_t = int64_t;
379
380
381
382
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
383
void dispose(fptr_t _fa);
384
int64_t meta_size();
385
386
387
388
389
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
390
                            const std::vector<std::vector<int64_t>>& offsets);
391
#endif