ops.h 18.7 KB
Newer Older
1
2
#pragma once

3
#include <optional>
4
#include <torch/library.h>
5

6
7
#include "core/scalar_type.hpp"

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <vector>

/// Build a non-owning alias of a CUDA tensor.
///
/// The result is created with torch::from_blob over the same data pointer,
/// sizes, strides and options as `tensor`. No deleter is supplied, so the
/// alias does not keep the underlying allocation alive: the caller must make
/// sure the storage behind `tensor` outlives every use of the returned tensor.
///
/// Declared `inline` because this definition lives in a header; without it,
/// including the header from more than one translation unit yields
/// multiple-definition link errors (ODR violation).
///
/// @param tensor  Source tensor; must reside on a CUDA device.
/// @return A tensor that aliases `tensor`'s memory with identical sizes,
///         strides and options.
/// @throws std::runtime_error if `tensor` is not on a CUDA device.
inline torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // from_blob accepts the size/stride array views directly, so no
  // intermediate std::vector copies are required.
  return torch::from_blob(tensor.data_ptr(), tensor.sizes(), tensor.strides(),
                          tensor.options());
}

32
33
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
34
35
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
36
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
37
38
39
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
40
41
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
42

43
44
45
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
46
47
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
48
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
49
50
51
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
52
53
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
54

55
56
57
58
59
60
// Declaration only; the definition is not part of this header.
// `output` is the only argument taken by mutable reference. `output_lse` is an
// optional tensor passed by value, and the prefix/suffix output and lse
// tensors are passed by const reference.
// NOTE(review): presumably combines the prefix and suffix partial attention
// results (using their log-sum-exp terms) into `output` — confirm against the
// implementation.
void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse);
61
#ifndef USE_ROCM
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal);

// Declaration only; the definition is not part of this header.
// The first four tensors are taken by mutable reference, the remaining
// tensors by value. Shape annotations are given per parameter below.
void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,            // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,           // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,               // [BATCH, ]
    torch::Tensor kv_seqlens,              // [BATCH, ]
    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    // NOTE(review): slash_indices_count carries no shape annotation;
    // presumably [N_HEADS, ] like vertical_indices_count — TODO confirm.
    torch::Tensor slash_indices_count, int64_t context_size,
    int64_t block_size_M, int64_t block_size_N, bool causal);
86
87
#endif

88
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
89
              double epsilon);
90
91

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
92
                        torch::Tensor& weight, double epsilon);
93

94
95
96
97
98
void fused_qk_norm_rope(torch::Tensor& qkv, int64_t num_heads_q,
                        int64_t num_heads_k, int64_t num_heads_v,
                        int64_t head_dim, double eps, torch::Tensor& q_weight,
                        torch::Tensor& k_weight, torch::Tensor& cos_sin_cache,
                        bool is_neox, torch::Tensor& position_ids);
99

100
101
102
103
104
void apply_repetition_penalties_(torch::Tensor& logits,
                                 const torch::Tensor& prompt_mask,
                                 const torch::Tensor& output_mask,
                                 const torch::Tensor& repetition_penalties);

105
106
107
108
109
void top_k_per_row_prefill(const torch::Tensor& logits,
                           const torch::Tensor& rowStarts,
                           const torch::Tensor& rowEnds, torch::Tensor& indices,
                           int64_t numRows, int64_t stride0, int64_t stride1,
                           int64_t topK);
110

111
void top_k_per_row_decode(const torch::Tensor& logits, int64_t next_n,
112
113
114
                          const torch::Tensor& seqLens, torch::Tensor& indices,
                          int64_t numRows, int64_t stride0, int64_t stride1,
                          int64_t topK);
115

116
117
118
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
//                                torch::Tensor& weight, torch::Tensor& scale,
//                                double epsilon);
119

120
121
122
123
124
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
//                                          torch::Tensor& input,
//                                          torch::Tensor& residual,
//                                          torch::Tensor& weight,
//                                          torch::Tensor& scale, double epsilon);
125

126
127
128
129
130
131
132
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                      torch::Tensor const& input,
                                      torch::Tensor const& weight,
                                      torch::Tensor& scales,
                                      double const epsilon,
                                      std::optional<torch::Tensor> scale_ub,
                                      std::optional<torch::Tensor> residual);
133

134
135
136
137
138
139
140
void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& weight,
                              torch::Tensor& scales, double const epsilon,
                              std::optional<torch::Tensor> scale_ub,
                              std::optional<torch::Tensor> residual,
                              int64_t group_size, bool is_scale_transposed);

141
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
142
                      std::optional<torch::Tensor> key, int64_t head_size,
143
144
145
146
                      torch::Tensor& cos_sin_cache, bool is_neox);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
147
148
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
//                         torch::Tensor& scale);
149

150
#ifndef USE_ROCM
151
152
153
154
155
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                              torch::Tensor& output_block_scale,
                              torch::Tensor& input,
                              torch::Tensor& input_global_scale);
#endif
156
157
158
159
160
161
// void persistent_masked_m_silu_mul_quant(
//     const at::Tensor& input,   // (E, T, 2*H)
//     const at::Tensor& counts,  // (E)
//     at::Tensor& y_q,           // (E, T, H) [OUT]
//     at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
//     bool use_ue8m0);
162

163
164
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

165
166
167
168
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

169
170
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold);
171
172
// Declaration only; the definition is not part of this header.
// `alpha` and `limit` have default arguments here (1.702 and 7.0), so direct
// C++ callers may omit them.
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
                       double alpha = 1.702, double limit = 7.0);
173

174
175
176
void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
177

178
179
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

180
181
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);

182
183
184
185
186
187
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
                        torch::Tensor const& q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table, double scale);

188
189
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

190
#ifndef USE_ROCM
191
192
193

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
194
                       int64_t split_k_iters);
195
196
197

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
198
199
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
200

201
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
202
#endif
203

204
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
205
206
                              int64_t n,
                              std::optional<at::ScalarType> const& dtype);
207

208
209
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);
210

211
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
212
213
                              int64_t row);

214
215
216
217
218
219
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor expert_ids,
                          torch::Tensor num_tokens_post_padded, int64_t type,
                          int64_t row, int64_t top_k, int64_t tokens);

220
221
222
223
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
                              torch::Tensor topk_ids, int64_t top_k,
                              int64_t type, int64_t row, int64_t tokens);

224
225
int64_t ggml_moe_get_block_size(int64_t type);

226
#ifndef USE_ROCM
227
228
229
230

bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
231
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
232

233
234
235
236
237
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha);

238
239
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
240
                       torch::Tensor const& b_scales,
241
                       std::optional<torch::Tensor> const& bias);
242

243
244
245
246
247
void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
248
249
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch);
250

251
252
253
254
255
256
void cutlass_fp4_group_mm(
    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);

257
258
259
260
void get_cutlass_moe_mm_data(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
261
262
    const int64_t num_experts, const int64_t n, const int64_t k,
    const std::optional<torch::Tensor>& blockscale_offsets);
263

264
265
266
void get_cutlass_moe_mm_problem_sizes(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
267
268
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets,
    std::optional<bool> force_swap_ab = std::nullopt);
269

270
271
272
273
274
275
276
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                  torch::Tensor& problem_sizes1,
                                  torch::Tensor& problem_sizes2,
                                  const torch::Tensor& expert_num_tokens,
                                  const int64_t num_local_experts,
                                  const int64_t padded_m, const int64_t n,
                                  const int64_t k);
277

278
279
280
281
282
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
283
284
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias);
285

286
287
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

288
289
290
291
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                              torch::Tensor const& b, torch::Tensor const& e,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
292
                              std::optional<torch::Tensor> const& bias);
293

294
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
295
296
297
298

void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_scale,
                      torch::Tensor const& input_scale);
299
300
301
302
303
304

void scaled_fp4_experts_quant(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
305
306
307
308
309

void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
                               double fp8_max, bool scale_ue8m0);
310
311
312
313
314

void per_token_group_quant_int8(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double int8_min, double int8_max);
315
316
317
318
319
320
321
322

// Fused activation quantisation + DeepGEMM-compatible UE8M0-packed scales.
void per_token_group_quant_8bit_packed(const torch::Tensor& input,
                                       torch::Tensor& output_q,
                                       torch::Tensor& output_s_packed,
                                       int64_t group_size, double eps,
                                       double min_8bit, double max_8bit);

323
#endif
324

325
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
326
                              torch::Tensor const& scale,
327
                              std::optional<torch::Tensor> const& azp);
328

329
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
330
                               torch::Tensor& scales,
331
                               std::optional<torch::Tensor> const& azp);
332

333
334
335
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
336
                        bool use_exllama, bool use_v2_format, int64_t bit);
337

338
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
339

340
341
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                              torch::Tensor const& scale);
342

343
344
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                               torch::Tensor& scale);
345

346
347
348
// void dynamic_per_token_scaled_fp8_quant(
//     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//     std::optional<torch::Tensor> const& scale_ub);
349

350
351
352
353
354
355
356
357
358
359
360
361
362
void selective_scan_fwd(
    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
    const torch::Tensor& B, const torch::Tensor& C,
    const std::optional<torch::Tensor>& D_,
    const std::optional<torch::Tensor>& z_,
    const std::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
    const std::optional<torch::Tensor>& query_start_loc,
    const std::optional<torch::Tensor>& cache_indices,
    const std::optional<torch::Tensor>& has_initial_state,
    const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size,
    const std::optional<torch::Tensor>& block_idx_first_scheduled_token,
    const std::optional<torch::Tensor>& block_idx_last_scheduled_token,
    const std::optional<torch::Tensor>& initial_state_idx);
363

364
365
366
367
368
369
torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
    int64_t activation_kind);

370
// 64-bit integer handle type used by the all-reduce declarations that follow
// (returned by init_custom_ar and accepted as the `_fa` argument).
// NOTE(review): presumably an opaque pointer carried as an integer — confirm
// against the implementation.
using fptr_t = int64_t;
371
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
372
373
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected);
374
375
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
376
void dispose(fptr_t _fa);
377
int64_t meta_size();
378
379
380
381
382
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
383
                            const std::vector<std::vector<int64_t>>& offsets);
384
385
386
387
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
    int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
388

389
390
torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);

391
392
393
394
395
396
397
398
399
#ifdef USE_ROCM
// Declarations only; these are visible only when USE_ROCM is defined (see the
// enclosing #ifdef). The definitions are not part of this header.
// init_custom_qr returns an fptr_t handle; the other qr_* functions that take
// `_fa` accept such a handle. `qr_max_size` (on init_custom_qr) defaults to
// std::nullopt and `cast_bf2half` (on qr_all_reduce) defaults to false.
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
                      std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                   int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
400
#endif