#include "cache.h"
#include "ops.h"
#include "core/registration.h"

#include <torch/library.h>

// Note: overwrite the external definition so that libraries built for
// different ISAs can share the same extension name.
#define TORCH_EXTENSION_NAME _C
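
// With TORCH_EXTENSION_NAME expanded to _C, every op registered below is
// reachable from Python as torch.ops._C.<op_name> (for example,
// torch.ops._C.onednn_mm).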

std::string init_cpu_threads_env(const std::string& cpu_ids);

void release_dnnl_matmul_handler(int64_t handler);

int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
                                        const torch::Tensor& b_scales,
                                        at::ScalarType output_type,
                                        bool dynamic_act_quant, bool use_azp,
                                        int64_t primitive_cache_size);

void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                      const torch::Tensor& a_scales,
                      const std::optional<torch::Tensor>& azp,
                      const std::optional<torch::Tensor>& azp_adj,
                      const std::optional<torch::Tensor>& bias,
                      const torch::Tensor& handler_tensor);

int64_t create_onednn_mm_handler(const torch::Tensor& b,
                                 int64_t primitive_cache_size);

void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
               const std::optional<torch::Tensor>& bias,
               const torch::Tensor& handler_tensor);

bool is_onednn_acl_supported();
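
// Sketch of the oneDNN handler lifecycle implied by the declarations above
// (illustrative only; how the integer handle gets wrapped into the
// `handler_tensor` argument is up to the caller):
//
//   int64_t h = create_onednn_mm_handler(b, /*primitive_cache_size=*/128);
//   onednn_mm(c, a, /*bias=*/std::nullopt, handler_tensor);  // reuse per GEMM
//   release_dnnl_matmul_handler(h);  // free the cached primitives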

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens);

int64_t init_shm_manager(const std::string& name, const int64_t group_size,
                         const int64_t rank, const int64_t thread_num);

std::string join_shm_manager(int64_t handle, const std::string& name);

void shm_allreduce(int64_t handle, torch::Tensor& data);

void shm_gather(int64_t handle, torch::Tensor& data,
                const std::optional<std::vector<torch::Tensor>>& outputs,
                int64_t dst);

void shm_all_gather(int64_t handle, const torch::Tensor& data,
                    torch::Tensor& output);

void shm_send_tensor_list(int64_t handle,
                          const std::vector<torch::Tensor>& tensor_list,
                          int64_t dst);

std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
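
// One plausible call sequence for the SHM collectives above (sketch; every
// rank in the group runs the same steps with its own `rank` value):
//
//   int64_t h = init_shm_manager("shm_name", /*group_size=*/4, rank, threads);
//   join_shm_manager(h, name_from_rank0);  // attach all ranks to one segment
//   shm_allreduce(h, data);                // in-place reduction across ranks
//   shm_all_gather(h, data, output);       // gather shards into `output`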

at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
                                const std::optional<at::Tensor>& bias,
                                bool is_vnni);

at::Tensor convert_weight_packed(at::Tensor& weight);

at::Tensor fused_experts_cpu(
    at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
    at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace,
    bool use_int8_w8a8, bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale, bool is_vnni);

at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
                                     at::Tensor& scales2,
                                     const std::optional<at::Tensor>& bias,
                                     at::ScalarType out_dtype, bool is_vnni);

torch::Tensor get_scheduler_metadata(
    const int64_t num_req, const int64_t num_heads_q,
    const int64_t num_heads_kv, const int64_t head_dim,
    const torch::Tensor& seq_lens, at::ScalarType dtype,
    const torch::Tensor& query_start_loc, const bool causal,
    const int64_t window_size, const std::string& isa_hint,
    const bool enable_kv_split);

void cpu_attn_reshape_and_cache(const torch::Tensor& key,
                                const torch::Tensor& value,
                                torch::Tensor& key_cache,
                                torch::Tensor& value_cache,
                                const torch::Tensor& slot_mapping,
                                const std::string& isa);

void cpu_attention_with_kv_cache(
    const torch::Tensor& query, const torch::Tensor& key_cache,
    const torch::Tensor& value_cache, torch::Tensor& output,
    const torch::Tensor& query_start_loc, const torch::Tensor& seq_lens,
    const double scale, const bool causal,
    const std::optional<torch::Tensor>& alibi_slopes,
    const int64_t sliding_window_left, const int64_t sliding_window_right,
    const torch::Tensor& block_table, const double softcap,
    const torch::Tensor& scheduler_metadata,
    const std::optional<torch::Tensor>& s_aux);
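
// Sketch of how the three attention ops above compose for one batch
// (illustrative; "amx" stands in for whatever isa_hint the build supports,
// and the trailing arguments of the last call follow its declaration):
//
//   auto meta = get_scheduler_metadata(num_req, num_heads_q, num_heads_kv,
//                                      head_dim, seq_lens, dtype,
//                                      query_start_loc, /*causal=*/true,
//                                      window_size, "amx", enable_kv_split);
//   cpu_attn_reshape_and_cache(key, value, key_cache, value_cache,
//                              slot_mapping, "amx");
//   cpu_attention_with_kv_cache(query, key_cache, value_cache, output, ...);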

// Note: placeholder to avoid import errors for ops not built on this platform
void placeholder_op() { TORCH_CHECK(false, "Unimplemented"); }

void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
                    torch::Tensor& output, const torch::Tensor& scales,
                    const std::optional<torch::Tensor>& zeros,
                    const std::optional<torch::Tensor>& g_idx,
                    const std::optional<torch::Tensor>& bias,
                    const int64_t pack_factor, const std::string& isa_hint);

void prepack_moe_weight(const torch::Tensor& weight,
                        torch::Tensor& packed_weight, const std::string& isa);

void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
                   const torch::Tensor& w13, const torch::Tensor& w2,
                   const std::optional<torch::Tensor>& w13_bias,
                   const std::optional<torch::Tensor>& w2_bias,
                   const torch::Tensor& topk_weights,
                   const torch::Tensor& topk_id, const bool skip_weighted,
                   const std::string& act, const std::string& isa);

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  ops.def(
      "dynamic_4bit_int_moe("
      "Tensor x, Tensor topk_ids, Tensor topk_weights,"
      "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
      "int group_size, bool apply_router_weight_on_input, int activation_kind"
      ") -> Tensor");

  ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);

  // Activation ops

  // Activation function used in SwiGLU.
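  // For an input of shape [..., 2 * d] the kernel writes an output of shape
  // [..., d]: out = silu(input[..., :d]) * input[..., d:], with the gate and
  // up projections concatenated along the last dimension.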
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
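  // Expected semantics given the two mutable arguments: residual is updated
  // to input + residual, and input is overwritten with
  // rms_norm(input + residual) * weight.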
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor!? key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

  // Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
    defined(__powerpc64__)
  // Helper function to release oneDNN handlers
  ops.def("release_dnnl_matmul_handler(int handler) -> ()",
          &release_dnnl_matmul_handler);

  // Create oneDNN GEMM handler
  ops.def(
      "create_onednn_mm_handler(Tensor b, int "
      "primitive_cache_size) -> int",
      &create_onednn_mm_handler);

  // oneDNN GEMM
  ops.def(
      "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
      "Tensor handler_tensor) -> ()");
  ops.impl("onednn_mm", torch::kCPU, &onednn_mm);

  // Check if oneDNN was built with ACL backend
  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);

  // Create oneDNN W8A8 handler
  ops.def(
      "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
      "output_type, bool dynamic_act_quant, bool use_azp, int "
      "primitive_cache_size) -> int",
      &create_onednn_scaled_mm_handler);

  // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
  ops.def(
      "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
      "Tensor? azp_adj, Tensor? bias, Tensor handler_tensor) -> ()");
  ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
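
  // Sketch of a static-quantization W8A8 path built from the ops above
  // (illustrative; the azp arguments stay empty for symmetric quantization):
  //
  //   static_scaled_int8_quant(a_int8, a_fp32, a_scales, /*azp=*/std::nullopt);
  //   onednn_scaled_mm(c, a_int8, a_scales, /*azp=*/std::nullopt,
  //                    /*azp_adj=*/std::nullopt, bias, handler_tensor);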
#endif

// SHM CCL
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
  ops.def(
      "init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
      "int",
      &init_shm_manager);
  ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
  ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
  ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
  ops.def(
      "shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
      "()");
  ops.impl("shm_gather", torch::kCPU, &shm_gather);
  ops.def(
      "shm_all_gather(int handle, Tensor data, Tensor! output) -> "
      "()");
  ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
  ops.def(
      "shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
      "()");
  ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
  ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
          &shm_recv_tensor_list);
#endif  // defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))

  // sgl-kernels
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
  ops.def(
      "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
      "bias, bool is_vnni) -> Tensor");
  ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
  ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
  ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
  ops.def(
      "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
      "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
      "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
      "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
      "Tensor");
  ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
  ops.def(
      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
      "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
  ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
           &int8_scaled_mm_with_quant);
#endif

  // CPU attention kernels
  ops.def(
      "get_scheduler_metadata(int num_req, int num_heads_q, int num_heads_kv, "
      "int head_dim, Tensor seq_lens, ScalarType dtype, Tensor "
      "query_start_loc, bool casual, int window_size, str isa_hint, bool "
      "enable_kv_split) -> Tensor",
      &get_scheduler_metadata);
  ops.def(
      "cpu_attn_reshape_and_cache(Tensor key, Tensor value, Tensor(a2!) "
      "key_cache, Tensor(a3!) value_cache, Tensor slot_mapping, str "
      "isa) -> ()",
      &cpu_attn_reshape_and_cache);
  ops.def(
      "cpu_attention_with_kv_cache(Tensor query, Tensor key_cache, Tensor "
      "value_cache, Tensor(a3!) output, Tensor query_start_loc, Tensor "
      "seq_lens, float scale, bool causal, Tensor? alibi_slopes, SymInt "
      "sliding_window_left, SymInt sliding_window_right, Tensor block_table, "
      "float softcap, Tensor scheduler_metadata, Tensor? s_aux) -> ()",
      &cpu_attention_with_kv_cache);

  // placeholders
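  // (registering these schemas keeps the torch.ops._C surface identical
  // across builds; any call fails loudly via placeholder_op instead of
  // breaking at import time)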
  ops.def("static_scaled_fp8_quant() -> ()", placeholder_op);
  ops.def("dynamic_scaled_fp8_quant() -> ()", placeholder_op);
  ops.def("dynamic_per_token_scaled_fp8_quant() -> ()", placeholder_op);

  // WNA16
#if defined(__AVX512F__)
  ops.def(
      "cpu_gemm_wna16(Tensor input, Tensor q_weight, Tensor(a2!) output, "
      "Tensor scales, Tensor? zeros, Tensor? g_idx, Tensor? bias, SymInt "
      "pack_factor, str isa_hint) -> ()");
  ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif

  // fused moe
#if defined(__AVX512F__)
  ops.def(
      "prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
      "-> ()");
  ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
  ops.def(
      "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
      "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
      "bool skip_weighted, "
      "str act, str isa) -> ()");
  ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
  ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
  ops.def(
      "mla_decode_kvcache("
      "   Tensor! out, Tensor query, Tensor kv_cache,"
      "   float scale, Tensor block_tables, Tensor seq_lens) -> ()");
  ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)