torch_extension_cpu.cpp 13.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <optional>
#include <tuple>
#include <vector>

#include <ATen/ATen.h>
#include <torch/all.h>
#include <torch/library.h>

#include "sgl_kernel_ops.h"
#include "shm.h"

// silu_and_mul
// Returns the activation result as a tensor.
at::Tensor silu_and_mul_cpu(at::Tensor& input);

// l2norm
// Returns the normalized result as a tensor; `eps` is presumably the usual
// numerical-stability epsilon (TODO confirm against the definition).
at::Tensor l2norm_cpu(at::Tensor& input, double eps);

// rmsnorm
// Returns the normalized result as a tensor; `weight` is presumably the
// per-channel scale (TODO confirm against the definition).
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);

// fused_add_rmsnorm
// Returns nothing, so its results can only be delivered by writing through the
// non-const tensor arguments (presumably `input` and `residual` — TODO confirm).
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);

// topk
36
37
38
39
40
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);

41
42
43
44
45
46
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
47
48
49
50
    int64_t topk_group,
    int64_t num_fused_shared_experts,
    std::optional<double> routed_scaling_factor,
    std::optional<at::Tensor> num_token_non_padded);
51
52
53
54
55
56
57
58

std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    at::Tensor& correction_bias,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
59
60
61
62
    int64_t topk_group,
    int64_t num_fused_shared_experts,
    std::optional<double> routed_scaling_factor,
    std::optional<at::Tensor> num_token_non_padded);
63
64
65
66
67

// attention
void decode_attention_cpu(
    at::Tensor& query,
    at::Tensor& k_cache,
68
69
70
71
72
    at::Tensor& v_cache,
    at::Tensor& output,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& loc,
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    at::Tensor& attn_logits,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    double sm_scale,
    double logit_cap);

void extend_attention_cpu(
    at::Tensor& q_extend,
    at::Tensor& k_extend,
    at::Tensor& v_extend,
    at::Tensor& o_extend,
    at::Tensor& k_buffer,
    at::Tensor& v_buffer,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    at::Tensor& extend_seq_lens,
    at::Tensor& extend_start_loc,
    int64_t max_len_extend,
    double sm_scale,
    double logit_cap);

// weight prepack
// Returns `weight` in a repacked form (presumably the layout consumed by the
// `is_vnni` GEMM paths below — TODO confirm).
at::Tensor convert_weight_packed(at::Tensor& weight);

// quant
// Per-token int8 quantization of `A`. Returns two tensors, presumably the
// quantized values and their per-token scales (TODO confirm).
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);

// gemm
// `bias` is optional (may be std::nullopt / None from Python).
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);

// igemm
at::Tensor int8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales1,
    at::Tensor& scales2,
blzheng's avatar
blzheng committed
112
    const std::optional<at::Tensor>& bias,
113
114
115
    at::ScalarType out_dtype,
    bool is_vnni);

116
117
118
119
120
121
// fp8 gemm
at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
blzheng's avatar
blzheng committed
122
    const std::optional<at::Tensor>& bias,
123
124
125
    at::ScalarType out_dtype,
    bool is_vnni);

126
127
128
129
130
// quant + igemm
at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
blzheng's avatar
blzheng committed
131
    const std::optional<at::Tensor>& bias,
132
133
134
135
    at::ScalarType out_dtype,
    bool is_vnni);

// bmm
blzheng's avatar
blzheng committed
136
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
137
138
139
140
141
142
143
144
145
146

// fused moe
// The `use_int8_w8a8` / `use_fp8_w8a16` flags presumably select which of the
// optional scale and `block_size` arguments are consulted (TODO confirm).
// NOTE(review): with `inplace` set the kernel may write into `hidden_states`;
// if so, the registered schema should carry an alias annotation — confirm.
at::Tensor fused_experts_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& topk_weights,
    at::Tensor& topk_ids,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// Shared-expert counterpart of fused_experts_cpu; additionally takes
// `fused_experts_out` and `routed_scaling_factor`, and has no routing inputs.
at::Tensor shared_expert_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& fused_experts_out,
    double routed_scaling_factor,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// weight absorption
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& hidden_states,
    at::Tensor& q_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& kv_a_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
184
    bool use_fp8_w8a16,
blzheng's avatar
blzheng committed
185
186
187
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
188
189
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size);
190

191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
    at::Tensor& hidden_states,
    at::Tensor& qkv_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> qkv_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size,
    int64_t q_lora_rank,
    int64_t kv_lora_rank,
    int64_t qk_rope_head_dim);

211
// shared memory init
blzheng's avatar
blzheng committed
212
void initialize(int64_t size, int64_t rank);
213
214

// shared mmeory all_reduce
blzheng's avatar
blzheng committed
215
216
void shm_allreduce(
    at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
217
218

// shared memory all_gather
blzheng's avatar
blzheng committed
219
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
220
221

// rope
222
223
224
225
226
227
228
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox);
229

blzheng's avatar
blzheng committed
230
// Operator registrations for the `sgl_kernel` library. Each m.def() declares an
// operator schema, and the m.impl() that follows binds the CPU kernel declared
// above to it.
//
// Kernels whose schema returns `()` can only deliver results by writing into their
// tensor arguments, so those arguments carry `Tensor(x!)` alias annotations. The
// autograd fallback and torch.compile's functionalization rely on these annotations.
// Without them, an in-place write is treated as a pure read. Annotations are
// applied conservatively: marking a tensor the kernel does not actually write only
// costs an extra copy under functionalization, whereas a missing one is unsafe.
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // activation
  m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);

  // norm
  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
  m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
  m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
  // Returns `()`: presumably updates `input` and `residual` in place.
  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor(b!) residual, Tensor weight, float eps) -> ()");
  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

  // topk
  m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
  m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
  m.def(
      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
      "int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, Tensor? num_token_non_padded) -> "
      "(Tensor, Tensor)");
  m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);

  // biased group topk
  m.def(
      "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
      "renormalize, int num_expert_group, int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, "
      "Tensor? num_token_non_padded) -> (Tensor, Tensor)");
  m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);

  // decode
  // Returns `()`: presumably writes into `output`, the KV caches and the
  // `attn_logits` workspace, so all four are annotated as mutable.
  // `v_cahce` keeps its original (misspelled) name so that existing
  // keyword-argument callers keep working.
  m.def(
      "decode_attention_cpu(Tensor query, Tensor(a!) k_cache, Tensor(b!) v_cahce, Tensor(c!) output, Tensor key, "
      "Tensor value, Tensor loc, Tensor(d!) attn_logits, Tensor req_to_token, Tensor req_pool_indices, "
      "Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
  m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

  // extend
  // Returns `()`: results are presumably written into `o_extend`.
  m.def(
      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
      "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
      "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
  m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);

  // weight prepack
  m.def("convert_weight_packed(Tensor weight) -> Tensor");
  m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);

  // quant
  m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
  m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);

  // gemm
  m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
  m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);

  // igemm
  m.def(
      "int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
      "out_dtype, bool is_vnni) -> Tensor");
  m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);

  // fp8 gemm
  m.def(
      "fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
      "out_dtype, bool is_vnni) -> Tensor");
  m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);

  // quant + igemm
  m.def(
      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
      "is_vnni) -> Tensor");
  m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

  // bmm
  // Returns `()`: the product is presumably written into `out`.
  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
  m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

  // moe
  // NOTE(review): fused_experts_cpu / shared_expert_cpu take an `inplace` flag
  // but return a Tensor; if they also write `hidden_states`, these schemas need
  // an alias annotation too — confirm against the definitions.
  m.def(
      "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
      "inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, "
      "Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
  m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);

  // weight absorption
  m.def(
      "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
      "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
      "q_b_proj_scale, Tensor? kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
  m.def(
      "qkv_proj_with_rope_fused_weight(Tensor hidden_states, Tensor qkv_a_proj_weight, Tensor q_b_proj_weight, "
      "Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? qkv_a_proj_scale, "
      "Tensor? q_b_proj_scale, bool is_vnni, int[]? block_size, int q_lora_rank, int kv_lora_rank, "
      "int qk_rope_head_dim) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope_fused_weight", torch::kCPU, &qkv_proj_with_rope_fused_weight);

  // shared expert
  m.def(
      "shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
      "routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
      "w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
  m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);

  // shared-memory collectives (init, all_reduce, all_gather)
  m.def("initialize(int size, int rank) -> ()");
  m.impl("initialize", torch::kCPU, &initialize);
  // Returns `()`: presumably reduces `data` in place.
  m.def(
      "shm_allreduce(Tensor(a!) data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
      "__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
  m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
  m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
  m.impl("shm_allgather", torch::kCPU, &shm_allgather);

  // rope
  m.def(
      "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
      "bool is_neox) -> (Tensor, Tensor)");
  m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
}

// Extension-registration macro; it is not defined in this file (presumably it
// comes from "sgl_kernel_ops.h"). `common_ops` is presumably the name of the
// module it registers — TODO confirm against the macro definition.
REGISTER_EXTENSION(common_ops)