#pragma once

#include <optional>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

/// Returns a non-owning alias ("weak reference") of a CUDA tensor.
///
/// The result views exactly the same memory as `tensor` (same data pointer,
/// sizes, strides and options). It is built with torch::from_blob and no
/// deleter, so it does NOT keep the underlying storage alive. The caller must
/// guarantee that `tensor`'s storage outlives every use of the returned tensor.
///
/// Declared `inline` because this function is defined in a header. Without
/// it, every translation unit that includes this header emits its own
/// external definition, which violates the one-definition rule and produces
/// duplicate-symbol link errors.
///
/// @param tensor  Source tensor; must reside on a CUDA device.
/// @return        A tensor aliasing `tensor`'s memory without owning it.
/// @throws std::runtime_error if `tensor` is not on a CUDA device.
inline torch::Tensor weak_ref_tensor(torch::Tensor& tensor) {
  if (!tensor.is_cuda()) {
    throw std::runtime_error("Tensor must be on CUDA device");
  }

  // data_ptr() already includes the storage offset, so pairing it with the
  // tensor's own sizes/strides reproduces the identical view. The IntArrayRef
  // views are passed straight through: from_blob copies the shape metadata
  // into the new tensor, so no temporary std::vector copies are needed.
  return torch::from_blob(tensor.data_ptr(), tensor.sizes(), tensor.strides(),
                          tensor.options());
}

32
33
void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
34
35
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
36
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
37
38
39
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
40
41
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
42

43
44
45
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
46
47
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
48
    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
49
50
51
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
52
53
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
54

55
56
57
58
59
60
61
62
63
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse);
#endif

64
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
65
              double epsilon);
66
67

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
68
                        torch::Tensor& weight, double epsilon);
69

70
71
72
73
74
75
76
77
78
79
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                               torch::Tensor& weight, torch::Tensor& scale,
                               double epsilon);

void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
                                         torch::Tensor& input,
                                         torch::Tensor& residual,
                                         torch::Tensor& weight,
                                         torch::Tensor& scale, double epsilon);

80
81
82
83
84
85
86
87
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
                                      torch::Tensor const& input,
                                      torch::Tensor const& weight,
                                      torch::Tensor& scales,
                                      double const epsilon,
                                      std::optional<torch::Tensor> scale_ub,
                                      std::optional<torch::Tensor> residual);

88
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
89
                      torch::Tensor& key, int64_t head_size,
90
91
92
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
93
                              torch::Tensor& key, int64_t head_size,
94
                              torch::Tensor& cos_sin_cache, bool is_neox,
95
                              int64_t rot_dim,
96
97
98
99
                              torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

100
101
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

102
103
104
105
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

106
107
108
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
                     double threshold);

109
110
111
void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
112

113
114
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

// Step-advance entry points for the FlashAttention and FlashInfer backends
// (declarations only). The FlashInfer variant additionally takes the
// paged-KV index tensors and `block_table_bounds`. The tensor arguments are
// non-const references and are presumably updated in place for the next
// decoding step; confirm against the implementation.
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables);

void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
130

131
132
133
134
135
136
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
                        torch::Tensor const& q_pe,
                        torch::Tensor const& kv_c_and_k_pe_cache,
                        torch::Tensor const& seq_lens,
                        torch::Tensor const& page_table, double scale);

137
138
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

139
#ifndef USE_ROCM
140
141
142
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
143
                        const std::vector<int64_t>& codebook_partition_sizes,
144
145
                        const std::optional<torch::Tensor>& bias);

146
147
148
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);
149
150
151

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
152
                       int64_t split_k_iters);
153
154
155

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
156
157
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
158

159
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
160
#endif
161

162
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
163
164
                              int64_t n,
                              std::optional<at::ScalarType> const& dtype);
165

166
167
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);
168

169
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
170
171
                              int64_t row);

172
173
174
175
176
177
178
179
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor expert_ids,
                          torch::Tensor num_tokens_post_padded, int64_t type,
                          int64_t row, int64_t top_k, int64_t tokens);

int64_t ggml_moe_get_block_size(int64_t type);

180
#ifndef USE_ROCM
181
182
183
184

bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
185
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
186

187
188
189
190
191
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                           torch::Tensor const& B, torch::Tensor const& A_sf,
                           torch::Tensor const& B_sf,
                           torch::Tensor const& alpha);

192
193
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
194
                       torch::Tensor const& b_scales,
195
                       std::optional<torch::Tensor> const& bias);
196

197
198
199
200
201
202
203
204
205
206
207
208
209
void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides);

void get_cutlass_moe_mm_data(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k);

210
211
212
213
214
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
215
216
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias);
217

218
219
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);

220
221
222
223
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
                              torch::Tensor const& b, torch::Tensor const& e,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
224
                              std::optional<torch::Tensor> const& bias);
225

226
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
227
228
229
230

void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
                      torch::Tensor& output_scale,
                      torch::Tensor const& input_scale);
231
#endif
232

233
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
234
                              torch::Tensor const& scale,
235
                              std::optional<torch::Tensor> const& azp);
236

237
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
238
                               torch::Tensor& scales,
239
                               std::optional<torch::Tensor> const& azp);
240

241
242
243
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
244
                        bool use_exllama, int64_t bit);
245

246
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
247

248
249
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                             torch::Tensor const& scale);
250

251
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
252
253
                              torch::Tensor& scale);

254
255
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
256
    std::optional<torch::Tensor> const& scale_ub);
257

258
259
260
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                        const torch::Tensor& A, const torch::Tensor& B,
                        const torch::Tensor& C,
261
262
263
                        const std::optional<torch::Tensor>& D_,
                        const std::optional<torch::Tensor>& z_,
                        const std::optional<torch::Tensor>& delta_bias_,
264
                        bool delta_softplus,
265
266
267
                        const std::optional<torch::Tensor>& query_start_loc,
                        const std::optional<torch::Tensor>& cache_indices,
                        const std::optional<torch::Tensor>& has_initial_state,
268
269
270
271
                        const torch::Tensor& ssm_states, int64_t pad_slot_id);

void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
                          const at::Tensor& weight,
272
                          const std::optional<at::Tensor>& bias_,
273
                          bool silu_activation,
274
275
                          const std::optional<at::Tensor>& cache_seqlens_,
                          const std::optional<at::Tensor>& conv_state_indices_,
276
277
278
                          int64_t pad_slot_id);

void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
279
280
281
282
283
                       const std::optional<at::Tensor>& bias_,
                       const std::optional<at::Tensor>& conv_states,
                       const std::optional<at::Tensor>& query_start_loc,
                       const std::optional<at::Tensor>& cache_indices,
                       const std::optional<at::Tensor>& has_initial_state,
284
                       bool silu_activation, int64_t pad_slot_id);
285

286
using fptr_t = int64_t;
287
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
288
289
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected);
290
291
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
292
void dispose(fptr_t _fa);
293
int64_t meta_size();
294
295
296
297
298
void register_buffer(fptr_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
299
                            const std::vector<std::vector<int64_t>>& offsets);
300
301
302
303
std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
    int64_t size);
int64_t open_mem_handle(torch::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);