"scripts/deprecated/test_openai_server.py" did not exist on "ae7ee01a8e59f755d47426c4b08641053b765a89"
torch_extension_cpu.cpp 13.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <ATen/ATen.h>
17
#include <torch/all.h>
18
19
#include <torch/library.h>

blzheng's avatar
blzheng committed
20
#include "sgl_kernel_ops.h"
21
22
23
24
25
#include "shm.h"

// silu_and_mul
// Forward declaration; bound to the CPU dispatch key in TORCH_LIBRARY_FRAGMENT below.
at::Tensor silu_and_mul_cpu(at::Tensor& input);

// gelu_and_mul
// Two entry points are declared: a tanh-named variant and a plain variant.
// Both take the input by const reference and return a new tensor handle.
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
at::Tensor gelu_and_mul_cpu(const at::Tensor& input);

30
31
32
// l2norm
// `eps` is a double here and is exposed as `float` in the operator schema.
at::Tensor l2norm_cpu(at::Tensor& input, double eps);

// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);

// fused_add_rmsnorm
// Returns nothing; the operator schema marks `input` as mutated (`Tensor(a!)`).
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);

// topk
// The sigmoid and softmax selectors share one signature and each return a pair of tensors.
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);

45
46
47
48
49
50
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
51
52
53
54
    int64_t topk_group,
    int64_t num_fused_shared_experts,
    std::optional<double> routed_scaling_factor,
    std::optional<at::Tensor> num_token_non_padded);
55
56
57
58
59
60
61
62

std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    at::Tensor& correction_bias,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
63
64
65
66
    int64_t topk_group,
    int64_t num_fused_shared_experts,
    std::optional<double> routed_scaling_factor,
    std::optional<at::Tensor> num_token_non_padded);
67
68
69
70
71

// attention
void decode_attention_cpu(
    at::Tensor& query,
    at::Tensor& k_cache,
72
73
74
75
76
    at::Tensor& v_cache,
    at::Tensor& output,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& loc,
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    at::Tensor& attn_logits,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    double sm_scale,
    double logit_cap);

void extend_attention_cpu(
    at::Tensor& q_extend,
    at::Tensor& k_extend,
    at::Tensor& v_extend,
    at::Tensor& o_extend,
    at::Tensor& k_buffer,
    at::Tensor& v_buffer,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    at::Tensor& extend_seq_lens,
    at::Tensor& extend_start_loc,
    int64_t max_len_extend,
    double sm_scale,
    double logit_cap);

// weight prepack
at::Tensor convert_weight_packed(at::Tensor& weight);

// quant
// Returns a pair of tensors for the input `A`.
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);

// gemm
// `bias` is optional (`Tensor?` in the operator schema).
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
109
110
111
112
113
114
115

// igemm
// Takes one scale tensor per operand; `bias` is optional and `out_dtype`
// is passed as a ScalarType.
at::Tensor int8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales1,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

120
121
122
123
124
125
// fp8 gemm
// Only `scales2` is passed (no scale tensor for `mat1`); `block_size` is taken
// by value and maps to `int[]` in the operator schema.
at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

130
131
132
133
134
// quant + igemm
// Unlike int8_scaled_mm_cpu, no `scales1` tensor is passed for `mat1`.
at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

// bmm
blzheng's avatar
blzheng committed
140
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
141
142
143
144
145
146
147
148
149
150

// fused moe
// All scale tensors and `block_size` are optional; `use_int8_w8a8` and
// `use_fp8_w8a16` are separate boolean flags.
at::Tensor fused_experts_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& topk_weights,
    at::Tensor& topk_ids,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// shared expert
// Takes the fused-experts result (`fused_experts_out`) and a non-optional
// `routed_scaling_factor`; the trailing scale/block_size arguments mirror
// fused_experts_cpu.
at::Tensor shared_expert_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& fused_experts_out,
    double routed_scaling_factor,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// weight absorption
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& hidden_states,
    at::Tensor& q_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& kv_a_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
188
    bool use_fp8_w8a16,
blzheng's avatar
blzheng committed
189
190
191
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
192
193
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size);
194

195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
    at::Tensor& hidden_states,
    at::Tensor& qkv_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> qkv_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size,
    int64_t q_lora_rank,
    int64_t kv_lora_rank,
    int64_t qk_rope_head_dim);

215
// shared memory init
blzheng's avatar
blzheng committed
216
void initialize(int64_t size, int64_t rank);
217
218

// shared mmeory all_reduce
219
void shm_allreduce(at::Tensor& data, int64_t op);
220
221

// shared memory all_gather
222
at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
223
224

// rope
225
226
227
228
229
230
231
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox);
232

233
234
235
// CPU and memory binding
std::string init_cpu_threads_env(const std::string& cpu_ids);

blzheng's avatar
blzheng committed
236
// Declares the operator schemas of the `sgl_kernel` library and binds the
// tensor-taking kernels to the CPU dispatch key. The two operators without
// tensor arguments (`initialize`, `init_cpu_threads_env`) are only declared
// here; their implementations are bound in TORCH_LIBRARY_IMPL(..., CatchAll, ...).
//
// Schema type mapping used throughout: `float` <-> double, `int` <-> int64_t,
// `T?` <-> std::optional<T>, `int[]` <-> std::vector<int64_t>.
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // activation
  m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
  m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
  m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);

  // norm
  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
  m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
  m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
  // NOTE(review): only `input` carries an alias annotation. If the kernel also
  // writes `residual` in place, the schema should mark it too — confirm
  // against the implementation.
  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

  // topk
  m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
  m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
  m.def(
      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
      "int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, Tensor? num_token_non_padded) -> "
      "(Tensor, Tensor)");
  m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);

  // biased group topk
  m.def(
      "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
      "renormalize, int num_expert_group, int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, "
      "Tensor? num_token_non_padded) -> (Tensor, Tensor)");
  m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);

  // decode
  // NOTE(review): the third schema argument is spelled `v_cahce` (the C++
  // parameter is `v_cache`). The name is part of the keyword-argument
  // interface, so it is left untouched here; rename only after checking that
  // no caller passes it by keyword.
  m.def(
      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
      "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
      "float logit_cap) -> ()");
  m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

  // extend
  m.def(
      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
      "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
      "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
  m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);

  // weight prepack
  m.def("convert_weight_packed(Tensor weight) -> Tensor");
  m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);

  // quant
  m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
  m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);

  // gemm
  m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
  m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);

  // igemm
  m.def(
      "int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
      "out_dtype, bool is_vnni) -> Tensor");
  m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);

  // fp8 gemm
  m.def(
      "fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
      "out_dtype, bool is_vnni) -> Tensor");
  m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);

  // quant + igemm
  m.def(
      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
      "is_vnni) -> Tensor");
  m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

  // bmm
  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
  m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

  // moe
  m.def(
      "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
      "inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? "
      "a1_scale, Tensor? a2_scale, bool "
      "is_vnni) -> Tensor");
  m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);

  // weight absorption
  m.def(
      "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
      "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
      "q_b_proj_scale, Tensor? "
      "kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
  // The adjacent literals below are concatenated by the compiler; some pieces
  // end in "," with no following space, so the joined schema reads
  // "...q_b_proj_scale,bool is_vnni...". The pieces are kept as they were.
  m.def(
      "qkv_proj_with_rope_fused_weight(Tensor hidden_states, Tensor qkv_a_proj_weight, Tensor q_b_proj_weight, "
      "Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? qkv_a_proj_scale, Tensor? "
      "q_b_proj_scale,"
      "bool is_vnni, int[]? block_size, int q_lora_rank, int kv_lora_rank,"
      "int qk_rope_head_dim) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope_fused_weight", torch::kCPU, &qkv_proj_with_rope_fused_weight);

  // shared expert
  m.def(
      "shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
      "routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
      "w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
  m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);

  // all reduce
  // `initialize` is schema-only here; its implementation is bound under CatchAll.
  m.def("initialize(int size, int rank) -> ()");
  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
  m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
  m.impl("shm_allgather", torch::kCPU, &shm_allgather);

  // rope
  m.def(
      "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
      "bool is_neox) -> (Tensor, Tensor)");
  m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);

  // CPU and memory binding
  // Schema-only here; the implementation is bound under CatchAll.
  m.def("init_cpu_threads_env(str cpu_ids) -> str");
}

// Binds the two operators that take no tensor arguments. Their schemas are
// declared in TORCH_LIBRARY_FRAGMENT above; they are registered under CatchAll
// instead of the CPU key used for the tensor-taking kernels.
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
  // Passed by explicit address-of, matching every other m.impl call in this file.
  m.impl("init_cpu_threads_env", &init_cpu_threads_env);
  m.impl("initialize", &initialize);
}
blzheng's avatar
blzheng committed
372
373

// NOTE(review): REGISTER_EXTENSION is not defined in this file; presumably it is
// provided by "sgl_kernel_ops.h" and emits the module entry point named
// `common_ops` — confirm against that header.
REGISTER_EXTENSION(common_ops)