// ops.h
#pragma once

#include <cstdint>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

void paged_attention_v1(
9
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
10
11
12
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
13
14
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
15
16
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
17
18

void paged_attention_v2(
19
20
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
21
22
23
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
24
25
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
26
27
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
28

zhuwenwen's avatar
zhuwenwen committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// "_opt" variants of the paged-attention entry points.
// Their parameter lists are identical to paged_attention_v1 and
// paged_attention_v2 respectively.
// NOTE(review): presumably alternative (optimized) kernels behind the same
// interface — confirm.
void paged_attention_v1_opt(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void paged_attention_v2_opt(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
51
              double epsilon);
52
53

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
54
                        torch::Tensor& weight, double epsilon);
55

zhuwenwen's avatar
zhuwenwen committed
56
57
58
59
60
61
// "_opt" variants of rms_norm / fused_add_rms_norm, with identical parameter
// lists.
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input,
                  torch::Tensor& weight, double epsilon);

void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
                            torch::Tensor& weight, double epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
63
                      torch::Tensor& key, int64_t head_size,
64
65
66
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
67
                              torch::Tensor& key, int64_t head_size,
68
                              torch::Tensor& cos_sin_cache, bool is_neox,
69
                              int64_t rot_dim,
70
                              torch::Tensor& cos_sin_cache_offsets);
huangwb's avatar
huangwb committed
71
72
73
74
75
76
77
// Rotary embedding variant that differs from rotary_embedding in two ways:
// it takes no `positions` tensor, and it uses separate `cos_cache` /
// `sin_cache` tensors instead of one combined `cos_sin_cache`.
void rotary_embedding_tgi(torch::Tensor& query, torch::Tensor& key,
                          int64_t head_size, torch::Tensor& cos_cache,
                          torch::Tensor& sin_cache, bool is_neox);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
85
86
87
88
89
90
void silu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);

91
92
93
void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);
94

95
96
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

zhuwenwen's avatar
zhuwenwen committed
97
98
// Unlike most declarations in this header, `dst` and `src` are taken by value.
// Keep them that way: the declaration must match the definition elsewhere.
// NOTE(review): `row` / `col` presumably describe the source matrix shape —
// confirm.
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row,
                    int64_t col);

// Per-step update of decode bookkeeping tensors.
// All tensors are non-const references.
// NOTE(review): presumably input_tokens / input_positions / seq_lens /
// slot_mapping are updated in place from sampled_token_ids and block_tables —
// confirm.
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);

#ifndef USE_ROCM
105
106
107
108
109
110
111
112
113
114
115
116
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes);

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
117
                       int64_t split_k_iters);
118
119
120

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
121
122
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);
123
124
125
126
127
128
129
130

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
131
132
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
133
134
135
136
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
137
138
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
139
140
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
141
                               int64_t size_m, int64_t size_n, int64_t size_k,
142
143
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce);
144
145
146
147
148

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

149
150
151
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

152
153
154
155
156
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

157
158
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

159
160
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
161
162
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);
163

164
165
166
167
168
169
170
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
                              torch::Tensor const& s_ch,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k);
171
#endif
172

173
174
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
175

176
177
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);
178

179
180
181
182
183
184
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
185
                        bool use_exllama, int64_t bit);
186

187
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
188

// FP8 quantization declarations, currently commented out.
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                              torch::Tensor const& scale);

// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
//                               torch::Tensor& scale);

// void dynamic_per_token_scaled_fp8_quant(
//     torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
//     c10::optional<torch::Tensor> const& scale_ub);

// MoE block alignment. All tensors are taken by value.
// NOTE(review): sorted_token_ids / experts_ids / num_tokens_post_pad are
// presumably output buffers filled by the kernel — confirm.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
205
using fptr_t = int64_t;
206
207
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
208
                      const std::vector<int64_t>& offsets, int64_t rank,
209
                      bool full_nvlink);
210
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
211
                      bool full_nvlink);
212
213
214
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
215
void dispose(fptr_t _fa);
216
int64_t meta_size();
217
218
219
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
220
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
221
222
223
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
224
#endif