// torch_bindings.cpp - CPU op registrations for the vLLM custom-op extension.
#include "cache.h"
#include "ops.h"
#include "core/registration.h"

#include <torch/library.h>

#include <optional>
#include <string>

// Exposed as `init_cpu_threads_env(str cpu_ids) -> str` in the `_utils`
// library. Takes the CPU id list as a string and returns a string; judging by
// the name it sets up the CPU threading environment for those ids (confirm
// against the definition, which lives outside this file).
auto init_cpu_threads_env(const std::string& cpu_ids) -> std::string;

// Int8 scaled GEMM CPU kernel, bound to the `cutlass_scaled_mm` op in the main
// library (registered only in __AVX512F__ builds). The result is written into
// `c` (the schema's `Tensor! out`); `bias` is optional. Defined elsewhere.
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                    const torch::Tensor& b, const torch::Tensor& a_scales,
                    const torch::Tensor& b_scales,
                    const std::optional<torch::Tensor>& bias);

// Asymmetric-quantization variant of int8_scaled_mm, bound to the
// `cutlass_scaled_mm_azp` op in the main library (registered only in
// __AVX512F__ builds). The result is written into `c`; `azp_adj` is required
// while `azp` and `bias` are optional. Defined elsewhere.
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                        const torch::Tensor& b, const torch::Tensor& a_scales,
                        const torch::Tensor& b_scales,
                        const torch::Tensor& azp_adj,
                        const std::optional<torch::Tensor>& azp,
                        const std::optional<torch::Tensor>& bias);

// CPU MLA decode kernel, registered as `mla_decode_kvcache` in the `_cpu`
// library. Only `out` is declared mutable (`Tensor! out`) in its schema.
// `scale` is a double because a schema `float` maps to C++ `double`.
// Defined elsewhere.
void mla_decode_kvcache(torch::Tensor& out,
                        torch::Tensor& query,
                        torch::Tensor& kv_cache,
                        double scale,
                        torch::Tensor& block_tables,
                        torch::Tensor& seq_lens);

// Registers vLLM's core custom ops together with their CPU kernels.
// Schema notation: `Tensor!` marks an argument the op mutates in place,
// `Tensor?` an optional argument, and `-> ()` an op that returns nothing.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  // Attention ops
  // Compute the attention between an input query and the cached keys/values
  // using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);

  // PagedAttention V2. Besides `out`, the schema marks exp_sums, max_logits
  // and tmp_out as mutated (Tensor!).
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
      "    int tp_rank, int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);

  // Activation ops

  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Quick GELU implementation.
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor! key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

  // Quantization
  // The int8 quantization and GEMM ops below are registered only when this
  // file is compiled with AVX-512F enabled; other builds do not define them
  // in this library.
#ifdef __AVX512F__
  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor (the schema marks
  // `scale` and the optional `azp` as outputs).
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);

  // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization. Backed by the CPU int8_scaled_mm kernel; the `cutlass_`
  // op name is presumably kept so callers can use one name across backends.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);

  // W8A8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization. Backed by the CPU int8_scaled_mm_azp kernel.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor azp_adj,"
      "                  Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif
}

// Registers the KV-cache management ops, with CPU kernels, in the
// `<TORCH_EXTENSION_NAME>_cache_ops` library.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);

  // Copy the cache blocks from src to dst.
  // NOTE(review): key_caches annotates the list *elements* as mutable
  // (`Tensor(a!)[]`), while value_caches annotates the *list* itself
  // (`Tensor[](b!)`). Confirm whether this asymmetry is intentional before
  // relying on alias analysis or functionalization for this op.
  cache_ops.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
      "Tensor block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  Tensor k_scale, Tensor v_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);

  // MLA cache write. Only kv_cache is marked mutable (Tensor!) in the schema.
  // Judging by the name, this presumably concatenates kv_c and k_pe into
  // kv_cache at the slots given by slot_mapping; confirm against the kernel.
  cache_ops.def(
      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
      "                     Tensor! kv_cache,"
      "                     Tensor slot_mapping,"
      "                     str kv_cache_dtype,"
      "                     Tensor scale) -> ()");
  cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
}

// Registers CPU utility ops in the `<TORCH_EXTENSION_NAME>_utils` library.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
  // CPU utils
  // def() is given the function directly, with no dispatch key (unlike the
  // torch::kCPU impl() calls elsewhere), so this single implementation
  // serves the op regardless of dispatch key.
  utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
}

// CPU-only ops that live in their own `<TORCH_EXTENSION_NAME>_cpu` library.
// Currently just the MLA decode kernel. Its schema marks only `out` as
// mutated; `float scale` binds to the kernel's `double` parameter.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
  constexpr const char* kMlaDecodeOp = "mla_decode_kvcache";

  cpu_ops.def("mla_decode_kvcache("
              "   Tensor! out, Tensor query, Tensor kv_cache,"
              "   float scale, Tensor block_tables, Tensor seq_lens) -> ()");
  cpu_ops.impl(kMlaDecodeOp, torch::kCPU, &mla_decode_kvcache);
}

// Module-level registration for the extension named TORCH_EXTENSION_NAME.
// REGISTER_EXTENSION comes from core/registration.h; presumably it defines the
// Python module entry point so the op libraries in this file are loaded on
// import. Confirm in that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)