ops.h 16.7 KB
Newer Older
1
2
#pragma once

3
#include <optional>
4
#include <torch/library.h>
5

6
7
#include "core/scalar_type.hpp"

8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <vector>

// Build a second tensor that aliases the memory of `tensor`.
//
// The result is created with torch::from_blob from `tensor`'s raw data
// pointer, sizes, strides and options (dtype, device). No deleter is passed
// to from_blob, so the returned tensor does not manage the lifetime of that
// memory: the caller must keep the original allocation alive for as long as
// the returned tensor is in use.
//
// Declared `inline` because the definition lives in a header; without it,
// every translation unit that includes this header emits its own external
// definition and the link fails with duplicate symbols.
//
// Parameters:
//   tensor - source tensor; must report is_cuda() == true.
// Returns:
//   a tensor over the same data pointer with identical sizes, strides and
//   options.
// Throws:
//   std::runtime_error if `tensor` is not on a CUDA device.
inline torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  // Ensure tensor is on CUDA
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // `tensor` outlives this call, so the array views returned by sizes() and
  // strides() stay valid while from_blob reads them; no intermediate
  // std::vector copies are needed.
  return torch::from_blob(tensor.data_ptr(), tensor.sizes(), tensor.strides(),
                          tensor.options());
}

32
33
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
34
35
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
36
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
37
38
39
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
40
41
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
42

43
44
45
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
46
47
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
48
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
49
50
51
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
52
53
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
54

55
56
57
58
59
60
61
#ifndef USE_ROCM
// Declaration only; the definition is not in this header. Compiled only when
// USE_ROCM is not defined (see the enclosing #ifndef).
//
// `output` is taken by non-const reference and `output_lse` is an optional
// tensor passed by value; the prefix/suffix output and lse tensors are const
// references.
// NOTE(review): presumably `output` (and `output_lse` when present) receive
// the merged result of the prefix and suffix attention states - confirm
// against the kernel implementation.
void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse);
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

// Declaration only; the definition is not in this header. Compiled only when
// USE_ROCM is not defined (see the enclosing #ifndef).
//
// The first four tensors are taken by non-const reference and the remaining
// four by value; the expected shapes are given in the trailing comments.
// NOTE(review): presumably the four reference parameters are outputs filled
// from the vertical/slash index inputs - confirm against the implementation.
void convert_vertical_slash_indexes(
    torch::Tensor& block_count,      // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,     // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,     // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,         // [BATCH, ]
    torch::Tensor kv_seqlens,        // [BATCH, ]
    torch::Tensor vertical_indexes,  // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,     // [BATCH, N_HEADS, NNZ_S]
    int64_t context_size, int64_t block_size_M, int64_t block_size_N,
    bool causal);

// Declaration only; the definition is not in this header. Compiled only when
// USE_ROCM is not defined (see the enclosing #ifndef).
//
// Same parameter list as convert_vertical_slash_indexes, plus two count
// tensors (`vertical_indices_count`, `slash_indices_count`) passed by value
// before the scalar arguments.
void convert_vertical_slash_indexes_mergehead(
    torch::Tensor& block_count,            // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& block_offset,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
    torch::Tensor& column_count,           // [BATCH, N_HEADS, NUM_ROWS]
    torch::Tensor& column_index,           // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
    torch::Tensor q_seqlens,               // [BATCH, ]
    torch::Tensor kv_seqlens,              // [BATCH, ]
    torch::Tensor vertical_indexes,        // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,           // [BATCH, N_HEADS, NNZ_S]
    torch::Tensor vertical_indices_count,  // [N_HEADS, ]
    // NOTE(review): slash_indices_count has no shape comment; presumably
    // [N_HEADS, ] like vertical_indices_count - confirm.
    torch::Tensor slash_indices_count, int64_t context_size,
    int64_t block_size_M, int64_t block_size_N, bool causal);
87
88
#endif

89
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
90
              double epsilon);
91
92

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
93
                        torch::Tensor& weight, double epsilon);
94

95
96
97
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               torch::Tensor& bias, double epsilon);

98
99
100
101
102
// Declaration only; the definition is not in this header.
//
// `logits` is the only non-const parameter; the two masks and the penalties
// tensor are const references.
// NOTE(review): the trailing underscore and the mutable `logits` reference
// suggest an in-place update of `logits` - confirm against the
// implementation.
void apply_repetition_penalties_(torch::Tensor& logits,
                                 const torch::Tensor& prompt_mask,
                                 const torch::Tensor& output_mask,
                                 const torch::Tensor& repetition_penalties);

103
104
105
// void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
//                                torch::Tensor& weight, torch::Tensor& scale,
//                                double epsilon);
106

107
108
109
110
111
// void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
//                                          torch::Tensor& input,
//                                          torch::Tensor& residual,
//                                          torch::Tensor& weight,
//                                          torch::Tensor& scale, double epsilon);
112

113
114
115
116
117
118
119
// Declaration only; the definition is not in this header.
//
// `out` and `scales` are non-const references; `input` and `weight` are const
// references. `scale_ub` and `residual` are optional tensors passed by value,
// so callers may pass std::nullopt for either.
// NOTE(review): presumably `out` receives the quantized result and `scales`
// the per-token scales - confirm against the kernel implementation.
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                      torch::Tensor const& input,
                                      torch::Tensor const& weight,
                                      torch::Tensor& scales,
                                      double const epsilon,
                                      std::optional<torch::Tensor> scale_ub,
                                      std::optional<torch::Tensor> residual);
120

121
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
122
                      std::optional<torch::Tensor> key, int64_t head_size,
123
124
125
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
126
127
128
                              std::optional<torch::Tensor> key,
                              int64_t head_size, torch::Tensor& cos_sin_cache,
                              bool is_neox, int64_t rot_dim,
129
130
131
132
                              torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
133
134
// void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
//                         torch::Tensor& scale);
135

136
#ifndef USE_ROCM
137
138
139
140
141
142
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                              torch::Tensor& output_block_scale,
                              torch::Tensor& input,
                              torch::Tensor& input_global_scale);
#endif

143
144
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

145
146
147
148
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

149
150
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold);
151
152
// Declaration only; the definition is not in this header.
//
// Unlike the other *_and_mul declarations in this file, this one takes two
// extra scalars with default values: `alpha` defaults to 1.702 and `limit`
// defaults to 7.0, so it can also be called with just (out, input).
// NOTE(review): the meaning of `alpha` and `limit` is not visible here -
// confirm against the kernel implementation.
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
                       double alpha = 1.702, double limit = 7.0);
153

154
155
156
void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
157

158
159
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

160
161
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);

162
163
164
165
166
167
// Declaration only; the definition is not in this header.
//
// Every tensor, including `out`, is taken by const reference.
// NOTE(review): other ops in this file take their output as a non-const
// `torch::Tensor&`; here `out` is const, so if the implementation writes the
// result it must do so through the tensor's data rather than by rebinding
// the handle - confirm this is intended.
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
                        torch::Tensor const& q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table, double scale);

168
169
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

170
#ifndef USE_ROCM
171
172
173

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
174
                       int64_t split_k_iters);
175
176
177

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
178
179
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
180

181
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
182
#endif
183

184
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
185
186
                              int64_t n,
                              std::optional<at::ScalarType> const& dtype);
187

188
189
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);
190

191
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
192
193
                              int64_t row);

194
195
196
197
198
199
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor expert_ids,
                          torch::Tensor num_tokens_post_padded, int64_t type,
                          int64_t row, int64_t top_k, int64_t tokens);

200
201
202
203
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
                              torch::Tensor topk_ids, int64_t top_k,
                              int64_t type, int64_t row, int64_t tokens);

204
205
int64_t ggml_moe_get_block_size(int64_t type);

206
#ifndef USE_ROCM
207
208
209
210

bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
211
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
212

213
214
215
216
217
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha);

218
219
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
220
                       torch::Tensor const& b_scales,
221
                       std::optional<torch::Tensor> const& bias);
222

223
224
225
226
227
void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
228
229
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch);
230

231
232
233
234
235
236
void cutlass_fp4_group_mm(
    torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
    const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
    const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
    const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);

237
238
239
240
void get_cutlass_moe_mm_data(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
241
242
    const int64_t num_experts, const int64_t n, const int64_t k,
    const std::optional<torch::Tensor>& blockscale_offsets);
243

244
245
246
247
248
void get_cutlass_moe_mm_problem_sizes(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);

249
250
251
252
253
254
255
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                  torch::Tensor& problem_sizes1,
                                  torch::Tensor& problem_sizes2,
                                  const torch::Tensor& expert_num_tokens,
                                  const int64_t num_local_experts,
                                  const int64_t padded_m, const int64_t n,
                                  const int64_t k);
256

257
258
259
260
261
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
262
263
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias);
264

265
266
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

267
268
269
270
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                              torch::Tensor const& b, torch::Tensor const& e,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
271
                              std::optional<torch::Tensor> const& bias);
272

273
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
274
275
276
277

void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_scale,
                      torch::Tensor const& input_scale);
278
279
280
281
282
283

void scaled_fp4_experts_quant(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
284
285
286
287
288

void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q, torch::Tensor& output_s,
                               int64_t group_size, double eps, double fp8_min,
                               double fp8_max, bool scale_ue8m0);
289
290
291
292
293

void per_token_group_quant_int8(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double int8_min, double int8_max);
294
#endif
295

296
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
297
                              torch::Tensor const& scale,
298
                              std::optional<torch::Tensor> const& azp);
299

300
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
301
                               torch::Tensor& scales,
302
                               std::optional<torch::Tensor> const& azp);
303

304
305
306
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
307
                        bool use_exllama, int64_t bit);
308

309
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
310

311
312
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                              torch::Tensor const& scale);
313

314
315
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                               torch::Tensor& scale);
316

317
318
319
// void dynamic_per_token_scaled_fp8_quant(
//     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//     std::optional<torch::Tensor> const& scale_ub);
320

321
322
323
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                        const torch::Tensor& A, const torch::Tensor& B,
                        const torch::Tensor& C,
324
325
326
                        const std::optional<torch::Tensor>& D_,
                        const std::optional<torch::Tensor>& z_,
                        const std::optional<torch::Tensor>& delta_bias_,
327
                        bool delta_softplus,
328
329
330
                        const std::optional<torch::Tensor>& query_start_loc,
                        const std::optional<torch::Tensor>& cache_indices,
                        const std::optional<torch::Tensor>& has_initial_state,
331
332
                        const torch::Tensor& ssm_states, int64_t pad_slot_id);

333
// Integer alias used as the handle type by the custom all-reduce declarations
// in this header (the `_fa` and `reg_buffer` parameters and the return value
// of init_custom_ar).
// NOTE(review): the name suggests a pointer carried as a 64-bit integer -
// confirm against the implementation.
using fptr_t = int64_t;
334
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
335
336
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected);
337
338
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
339
void dispose(fptr_t _fa);
340
int64_t meta_size();
341
342
343
344
345
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
346
                            const std::vector<std::vector<int64_t>>& offsets);
347
348
349
350
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
    int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
351
352
353
354
355
356
357
358
359
360

#ifdef USE_ROCM
// ROCm-only declarations (guarded by the enclosing #ifdef USE_ROCM); the
// definitions are not in this header. All of them except init_custom_qr and
// qr_max_size take an fptr_t handle as their first argument.
// NOTE(review): presumably that handle is the value returned by
// init_custom_qr - confirm against the implementation.

// `qr_max_size` is optional and defaults to std::nullopt.
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
                      std::optional<int64_t> qr_max_size = std::nullopt);
void qr_destroy(fptr_t _fa);
torch::Tensor qr_get_handle(fptr_t _fa);
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
// `cast_bf2half` defaults to false.
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                   int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
361
#endif