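// ops.h
// Forward declarations of the tensor operations used by this project
// (attention, normalization, activations, quantized GEMMs, custom all-reduce).
// Only declarations appear here; no function is defined in this header.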
#pragma once

#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <torch/library.h>

#include "core/scalar_type.hpp"

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
10
11
12
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
13
14
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
15
16
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
17

18
19
20
void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
21
22
23
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
24
25
    const std::string& kv_cache_dtype, double k_scale, double v_scale,
    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
26
27
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);
28
29

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
30
              double epsilon);
31
32

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
33
                        torch::Tensor& weight, double epsilon);
34
35

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
36
                      torch::Tensor& key, int64_t head_size,
37
38
39
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
40
                              torch::Tensor& key, int64_t head_size,
41
                              torch::Tensor& cos_sin_cache, bool is_neox,
42
                              int64_t rot_dim,
43
44
45
46
47
48
49
50
51
52
53
                              torch::Tensor& cos_sin_cache_offsets);

// Declaration only.
// NOTE(review): presumably a SiLU-gated multiply of `input` written to `out`
// -- confirm.
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration only.
// NOTE(review): presumably a GELU-gated multiply of `input` written to `out`
// -- confirm.
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration only.
// NOTE(review): presumably the tanh-approximated GELU-gated multiply of
// `input` written to `out` -- confirm.
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

// Declaration only.
// NOTE(review): presumably the "new" GELU approximation applied to `input`
// and written to `out` -- confirm.
void gelu_new(torch::Tensor& out, torch::Tensor& input);

// Declaration only.
// NOTE(review): presumably the "fast" GELU approximation applied to `input`
// and written to `out` -- confirm.
void gelu_fast(torch::Tensor& out, torch::Tensor& input);

// Declaration only.
// NOTE(review): presumably the "quick" GELU approximation applied to `input`
// and written to `out` -- confirm.
void gelu_quick(torch::Tensor& out, torch::Tensor& input);

// Declaration only. Takes sequence/query counts and a block size, followed by
// six tensors, all by non-const reference; returns void.
// NOTE(review): presumably updates the token, position, sequence-length and
// slot-mapping tensors in place for the next decoding step -- confirm.
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);

#ifndef USE_ROCM
63
64
65
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
66
                        const std::vector<int64_t>& codebook_partition_sizes,
67
68
                        const std::optional<torch::Tensor>& bias);

69
70
71
// AQLM dequantization entry point. Declaration only; returns a tensor.
// NOTE(review): the result is presumably the dequantized weight reconstructed
// from `codes` and `codebooks` -- confirm.
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes);

// AWQ GEMM entry point. Declaration only; returns a tensor.
// The tensor arguments are passed by value.
// NOTE(review): `split_k_iters` is presumably the number of split-K
// partitions -- confirm.
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);

// AWQ dequantization entry point. Declaration only; returns a tensor.
// The tensor arguments are passed by value.
// NOTE(review): `thx` / `thy` look like launch-shape hints; their meaning is
// not visible from this header -- confirm.
torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);

// Marlin GEMM entry point. Declaration only; returns a tensor.
// NOTE(review): `size_m` / `size_n` / `size_k` are presumably the GEMM problem
// dimensions and `workspace` a scratch buffer -- confirm.
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

// GPTQ Marlin 2:4 GEMM entry point. Declaration only; returns a tensor.
// `b_q_type` is a vllm::ScalarTypeTorchPtr, a project type not declared in
// this header (presumably provided by core/scalar_type.hpp).
// NOTE(review): `b_meta` is presumably the sparsity metadata for the 2:4
// format -- confirm.
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace,
                                  vllm::ScalarTypeTorchPtr const& b_q_type,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

// GPTQ Marlin GEMM entry point. Declaration only; returns a tensor.
// `b_q_type` is a vllm::ScalarTypeTorchPtr, a project type not declared in
// this header (presumably provided by core/scalar_type.hpp).
// NOTE(review): `has_zp` presumably says whether `b_zeros` holds zero points
// and `use_fp32_reduce` presumably selects FP32 accumulation -- confirm.
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& b_zeros,
                               torch::Tensor& g_idx, torch::Tensor& perm,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeTorchPtr const& b_q_type,
                               int64_t size_m, int64_t size_n, int64_t size_k,
                               bool is_k_full, bool has_zp,
                               bool use_fp32_reduce);

// GPTQ -> Marlin weight repack entry point. Declaration only; returns a
// tensor.
// NOTE(review): presumably returns `b_q_weight` rearranged (using `perm`)
// into the layout the Marlin kernels expect -- confirm.
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

// AWQ -> Marlin weight repack entry point. Declaration only; returns a tensor.
// Unlike gptq_marlin_repack it takes no permutation tensor.
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                int64_t size_n, int64_t num_bits);

// GGML dequantization entry point. Declaration only; returns a tensor.
// NOTE(review): `type` is presumably a GGML quantization-type id and `m` / `n`
// the output dimensions -- confirm.
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                              int64_t n);

// GGML matrix-vector multiply entry point. Declaration only; returns a tensor.
// NOTE(review): the `_a8` suffix presumably refers to 8-bit activations, and
// `type` to a GGML quantization-type id -- confirm.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);

// GGML matrix-matrix multiply entry point. Declaration only; returns a tensor.
// Same parameter list as ggml_mul_mat_vec_a8.
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                              int64_t row);

// FP8 Marlin GEMM entry point. Declaration only; returns a tensor.
// NOTE(review): `size_m` / `size_n` / `size_k` are presumably the GEMM problem
// dimensions and `workspace` a scratch buffer -- confirm.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                              torch::Tensor& b_scales, torch::Tensor& workspace,
                              int64_t num_bits, int64_t size_m, int64_t size_n,
                              int64_t size_k);

// Declaration only. Returns a bool for the given `cuda_device_capability`.
// NOTE(review): presumably true when the CUTLASS scaled-mm path can handle
// FP8 on that device; the encoding of the capability value is not visible
// here -- confirm.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

// CUTLASS scaled matrix-multiply entry point. Declaration only.
// `out` is the only non-const tensor; `a`, `b` and both scale tensors are
// const references and `bias` is optional.
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);

// CUTLASS scaled matrix-multiply entry point with `azp` arguments.
// Declaration only. `out` is the only non-const tensor; `azp_adj` is
// required, while `azp` and `bias` are optional.
// NOTE(review): "azp" presumably stands for activation zero point -- confirm.
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);

// Marlin QQQ GEMM entry point. Declaration only; returns a tensor.
// Every tensor except `workspace` is a const reference.
// NOTE(review): `s_tok` / `s_ch` / `s_group` are presumably per-token,
// per-channel and per-group scales -- confirm.
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                              torch::Tensor const& b_q_weight,
                              torch::Tensor const& s_tok,
                              torch::Tensor const& s_ch,
                              torch::Tensor const& s_group,
                              torch::Tensor& workspace, int64_t size_m,
                              int64_t size_n, int64_t size_k);
#endif
147

148
149
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
150

151
152
153
// Dynamic-scale int8 quantization entry point. Declaration only.
// Both `out` and `scales` are non-const references; `input` is const.
// NOTE(review): `scales` is presumably computed from `input` and written back
// by the call -- confirm.
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);

// SqueezeLLM GEMM entry point. Declaration only.
// Returns void and takes every tensor by value.
// NOTE(review): the result is presumably written into one of the arguments
// (likely `mul`) -- confirm.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

// GPTQ GEMM entry point. Declaration only; returns a tensor.
// The tensor arguments are passed by value.
// NOTE(review): `use_exllama` presumably selects the exllama kernel path and
// `bit` the weight bit-width -- confirm.
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);

// GPTQ weight shuffle entry point. Declaration only; returns void.
// NOTE(review): presumably reorders `q_weight` in place according to `q_perm`
// -- confirm.
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

// Static-scale FP8 quantization entry point. Declaration only.
// `out` is mutable; `input` and `scale` are const references, so the scale is
// supplied by the caller rather than produced by the call.
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                             torch::Tensor const& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
168
169
                              torch::Tensor& scale);

170
171
172
// Dynamic per-token FP8 quantization entry point. Declaration only.
// `out` and `scale` are non-const references; `scale_ub` is optional.
// NOTE(review): `scale_ub` is presumably an upper bound applied to the
// computed scales -- confirm.
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
    c10::optional<torch::Tensor> const& scale_ub);

// MoE block-alignment entry point. Declaration only; returns void and takes
// every tensor by value.
// NOTE(review): presumably fills `sorted_token_ids`, `experts_ids` and
// `num_tokens_post_pad` from `topk_ids` so that each expert's tokens are
// padded to a multiple of `block_size` -- confirm.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
// Integer alias used for the `_fa` argument of the functions declared in this
// section and for the return value of init_custom_ar.
// NOTE(review): presumably carries a pointer to a custom all-reduce object
// stored as an integer -- confirm.
using fptr_t = int64_t;
// Custom all-reduce initialization entry point. Declaration only.
// Returns an fptr_t, the same type the other functions in this section take
// as `_fa`.
// NOTE(review): `handles` / `offsets` are presumably per-rank IPC handles and
// their offsets -- confirm.
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
186
                      bool full_nvlink);
187
188
189
// Declaration only.
// NOTE(review): presumably all-reduces `inp` into `out` for an input buffer
// that is already registered -- confirm.
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
// Declaration only. Like all_reduce_reg but takes an extra `reg_buffer`.
// NOTE(review): presumably stages an unregistered `inp` through `reg_buffer`
// before reducing into `out` -- confirm.
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
// Declaration only.
// NOTE(review): presumably releases the object identified by `_fa` -- confirm.
void dispose(fptr_t _fa);
// Declaration only. Takes no arguments and returns an int64_t.
// NOTE(review): presumably the byte size required for the `meta` tensor passed
// to init_custom_ar -- confirm.
int64_t meta_size();
// Declaration only.
// NOTE(review): presumably registers tensor `t` (with per-rank `handles` and
// `offsets`) with the object identified by `_fa` -- confirm.
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
196
197
198
    fptr_t _fa);
// Declaration only. Takes one handle string per entry of `handles` and a
// vector of offset vectors.
// NOTE(review): presumably registers the graph buffers described by
// get_graph_buffer_ipc_meta across ranks -- confirm.
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif