// pybinding.cpp: pybind11 bindings that expose fastllm to Python as the `pyfastllm` module.
#include "model.h"
#include "factoryllm.h"

namespace pyfastllm{
  // Value-returning adapters over fastllm's operator API, which writes into an
  // output-parameter Data and keeps changing upstream. Returning by value lets
  // pybind11 bind these directly.
  // TODO: optimize to reduce memory copies (every wrapper returns a fresh Data by value).

  /// Wraps fastllm::RMSNorm(input, weight, eps, output) and returns `output`.
  fastllm::Data RMSNorm(const fastllm::Data &input, const fastllm::Data &weight, float eps){
    fastllm::Data output;
    fastllm::RMSNorm(input, weight, eps, output);
    return output;
  }

  /// Wraps fastllm::LayerNorm(input, gamma, beta, axis, output) and returns `output`.
  fastllm::Data LayerNorm(fastllm::Data &input, fastllm::Data &gamma, fastllm::Data &beta, int axis){
    fastllm::Data output;
    fastllm::LayerNorm(input, gamma, beta, axis, output);
    return output;
  }

  /// Wraps fastllm::Linear(input, weight, bias, output) and returns `output`.
  fastllm::Data Linear(fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias){
    fastllm::Data output;
    fastllm::Linear(input, weight, bias, output);
    return output;
  }

  /// Wraps fastllm::MatMul(input0, input1, output, alpha) and returns `output`.
  fastllm::Data MatMul(const fastllm::Data &input0, const fastllm::Data &input1, float alpha){
    fastllm::Data output;
    fastllm::MatMul(input0, input1, output, alpha);
    return output;
  }

  /// Wraps fastllm::Attention; `group`, `scale` and `attentionType` are forwarded unchanged.
  fastllm::Data Attention(const fastllm::Data &q, const fastllm::Data &k, const fastllm::Data &v, const fastllm::Data &mask,
                   int group, float scale, int attentionType) {
    fastllm::Data output;
    fastllm::Attention(q, k, v, mask, output, group, scale, attentionType);
    return output;
  }

  /// Wraps fastllm::Softmax(input, output, axis) and returns `output`.
  fastllm::Data Softmax(const fastllm::Data &input,int axis) {
    fastllm::Data output;
    fastllm::Softmax(input, output, axis);
    return output;
  }

  /// Wraps fastllm::Silu(input, output) and returns `output`.
  fastllm::Data Silu(const fastllm::Data &input) {
    fastllm::Data output;
    fastllm::Silu(input, output);
    return output;
  }

  /// Exposed as "gelu" but backed by fastllm::GeluNew.
  fastllm::Data Gelu(const fastllm::Data &input) {
    fastllm::Data output;
    fastllm::GeluNew(input, output);
    return output;
  }

  /// Wraps fastllm::Swiglu(input, output) and returns `output`.
  fastllm::Data Swiglu(const fastllm::Data &input) {
    fastllm::Data output;
    fastllm::Swiglu(input, output);
    return output;
  }

  /// Wraps fastllm::Mul(input, v, output) (tensor times scalar `v`) and returns `output`.
  fastllm::Data Mul(const fastllm::Data &input, float v){
    fastllm::Data output;
    fastllm::Mul(input, v, output);
    return output;
  }

  /// Accumulates `input1` into `input0` **in place** via fastllm::AddTo, then
  /// returns a copy of the updated `input0`.
  /// NOTE(review): assumes fastllm::AddTo(input0, input1, alpha) computes
  /// input0 += input1 * alpha — confirm against fastllm.h.
  fastllm::Data Add(fastllm::Data &input0, const fastllm::Data &input1, float alpha) {
    // `alpha` used to be dropped silently; forward it so the caller's scale is honored.
    fastllm::AddTo(input0, input1, alpha);
    return input0;
  }

  /// Renders a tensor for Python's __str__.
  /// Output is "[" + one line per row of the last dimension + "]". A row with
  /// more than 20 values is shortened to its first 10 and last 10 values,
  /// separated by "... ". Empty shapes, zero-width rows and tensors without
  /// host data render as "[]".
  /// NOTE(review): cpuData is read as float, so this assumes a FLOAT32 tensor
  /// resident in CPU memory — confirm for other dtypes/devices.
  std::string String(const fastllm::Data &data){
    // Guards: dims.back() on an empty shape is UB, a zero last dimension would
    // divide by zero below, and a null cpuData cannot be read.
    if (data.dims.empty() || data.cpuData == nullptr) {
      return "[]";
    }
    const int m = data.dims.back();
    if (m <= 0) {
      return "[]";
    }
    const int n = static_cast<int>(data.Count(0) / m);
    const float *values = reinterpret_cast<const float*>(data.cpuData);

    const int kEdge = 10;
    // Elide only when head and tail would not overlap. Previously, rows of
    // 11..19 values printed some elements twice.
    const bool elide = m > 2 * kEdge;
    const int head = elide ? kEdge : m;

    std::string ss = "[";
    for (int i = 0; i < n; i++) {
      if (i > 0) ss += "\n";
      const float *row = values + static_cast<size_t>(i) * m;
      for (int j = 0; j < head; j++) {
        if (j > 0) ss += " ";
        ss += std::to_string(row[j]);
      }
      if (elide) {
        ss += "... ";
        for (int j = m - kEdge; j < m; j++) {
          if (j > m - kEdge) ss += " ";
          ss += std::to_string(row[j]);
        }
      }
    }
    ss += "]";
    return ss;
  }
}

#ifdef PY_API
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <unordered_map>

namespace py = pybind11;
using namespace pybind11::literals;  

// Key/value cache type taken by the model `forward` bindings: a list of
// (key, value) tensor pairs. The vector itself stays non-opaque, so it is
// converted to and from Python lists by pybind11/stl.h.
using pastKV = std::vector<std::pair<fastllm::Data, fastllm::Data>>;

// Declare fastllm::Data opaque so pybind11 handles it only through the generic
// class caster (the `Tensor` binding) rather than any type-specific conversion.
PYBIND11_MAKE_OPAQUE(fastllm::Data);

// Python module `pyfastllm`. It exposes:
// - GenerationConfig and the global runtime settings;
// - tensor operators and the Tensor (fastllm::Data) type;
// - Tokenizer and WeightMap;
// - the ChatGLM / MOSS / Llama / QWen model classes.
PYBIND11_MODULE(pyfastllm, m) {
  m.doc() = "fastllm python bindings";

  // Generation parameters. `max_length` is the Python-side name for
  // GenerationConfig::output_token_limit.
  py::class_<fastllm::GenerationConfig>(m, "GenerationConfig")
    .def(py::init<>())
    .def_readwrite("max_length", &fastllm::GenerationConfig::output_token_limit)
    .def_readwrite("last_n", &fastllm::GenerationConfig::last_n)
    .def_readwrite("repeat_penalty", &fastllm::GenerationConfig::repeat_penalty)
    .def_readwrite("top_k", &fastllm::GenerationConfig::top_k)
    .def_readwrite("top_p", &fastllm::GenerationConfig::top_p)
    .def_readwrite("temperature", &fastllm::GenerationConfig::temperature)
    .def_readwrite("enable_hash_id", &fastllm::GenerationConfig::enable_hash_id)
    .def("is_simple_greedy", &fastllm::GenerationConfig::IsSimpleGreedy);

  // high level: global runtime settings and model creation
  m.def("set_threads", &fastllm::SetThreads)
    .def("get_threads", &fastllm::GetThreads)
    .def("set_low_memory", &fastllm::SetLowMemMode)
    .def("get_low_memory", &fastllm::GetLowMemMode)
    .def("set_kv_cache", &fastllm::SetKVCacheInCPU)
    .def("get_kv_cache", &fastllm::GetKVCacheInCPU)
    .def("set_device_map", &fastllm::SetDeviceMap)
    .def("create_llm", &fastllm::CreateLLMModelFromFile);
  // Exposes the C++ std::hash<std::string> value of `input` to Python.
  m.def("std_hash", [](const std::string &input) -> size_t {
    return std::hash<std::string>{}(input);
  });
  // low level
  m.def("get_llm_type", &fastllm::GetModelTypeFromFile);

  // Tensor operators (value-returning wrappers from namespace pyfastllm).
  m.def("llm_sampling", &fastllm::LLMSampling)
    // .def("embedding", &fastllm::Embedding)
    .def("rms_norm", &pyfastllm::RMSNorm)
    .def("layer_norm", &pyfastllm::LayerNorm)
    .def("linear", &pyfastllm::Linear)
    // .def("split", &fastllm::Split)
    // .def("cat", &fastllm::Cat)
    // .def("cat_direct", &fastllm::CatDirect)
    .def("matmul", &pyfastllm::MatMul)
    // .def("matmul_transB", &fastllm::MatMulTransB)
    .def("softmax", &pyfastllm::Softmax)
    .def("silu", &pyfastllm::Silu)
    .def("gelu", &pyfastllm::Gelu)
    .def("swiglu", &pyfastllm::Swiglu)
    .def("mul", &pyfastllm::Mul)
    .def("attention", &pyfastllm::Attention);
    // .def("mul_to", &fastllm::MulTo)
    // .def("add_to", &fastllm::AddTo)
    // .def("attention_mask", &fastllm::AttentionMask)
    // .def("alibi_mask", &fastllm::AlibiMask)
    // .def("permute", &fastllm::Permute)
    // .def("permute_self", &fastllm::PermuteSelf)
    // .def("topk", &fastllm::TopK)
    // .def("rotateposition2D", &fastllm::RotatePosition2D)
    // .def("nearlyrotateposition2D", &fastllm::NearlyRotatePosition2D)
    // .def("llama_rotateposition2D", &fastllm::LlamaRotatePosition2D)
    // .def("repeat_penalty", &fastllm::RepeatPenalty);

  py::enum_<fastllm::DataType>(m, "Dtype")
    .value("float32", fastllm::DataType::FLOAT32)
    .value("bfloat16", fastllm::DataType::BFLOAT16)
    .value("int16", fastllm::DataType::INT16)
    .value("int8", fastllm::DataType::INT8)
    .value("int4", fastllm::DataType::INT4)
    .value("int2", fastllm::DataType::INT2)
    .value("float16", fastllm::DataType::FLOAT16)
    .value("bit", fastllm::DataType::BIT)
    .value("int32param", fastllm::DataType::INT32PARAM)
    .export_values();

  py::class_<fastllm::Data>(m, "Tensor", py::buffer_protocol())
    // Buffer view over cpuData.
    // NOTE(review): assumes FLOAT32 data in host memory — confirm for other dtypes/devices.
    .def_buffer([](fastllm::Data &data) -> py::buffer_info {
        // Row-major (C-contiguous) shape and byte strides for any rank. The old
        // code hard-coded two strides from dims[1], which was wrong for rank != 2
        // and read out of bounds for rank < 2.
        std::vector<py::ssize_t> shape(data.dims.begin(), data.dims.end());
        std::vector<py::ssize_t> strides(shape.size());
        py::ssize_t stride = static_cast<py::ssize_t>(sizeof(float));
        for (size_t i = shape.size(); i-- > 0;) {
          strides[i] = stride;
          stride *= shape[i];
        }
        return py::buffer_info(
            data.cpuData,                           /* Pointer to buffer */
            sizeof(float),                          /* Size of one scalar */
            py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
            static_cast<py::ssize_t>(shape.size()), /* Number of dimensions */
            shape,                                  /* Buffer dimensions */
            strides                                 /* Strides (in bytes) for each index */
        );
     })
    .def_readonly("dims", &fastllm::Data::dims)
    .def(py::init<>())
    .def(py::init<fastllm::DataType>())
    .def(py::init<fastllm::DataType, const std::vector<int>&>())
    .def(py::init<fastllm::DataType, const std::vector<int>&, const std::vector<float>&>())
    .def(py::init<fastllm::Data>())
    .def_readonly("shape", &fastllm::Data::dims)
    .def("copy_from", &fastllm::Data::CopyFrom)
    .def("count", &fastllm::Data::Count)
    // Flattened copy of the host buffer read as float; empty when there is no host data.
    .def("to_list", [](fastllm::Data &data) {
      std::vector<float> vecData;
      if (data.cpuData == nullptr) {
        return vecData;
      }
      const float *values = reinterpret_cast<const float*>(data.cpuData);
      vecData.assign(values, values + data.Count(0));
      return vecData;
    })
    .def("__str__", &pyfastllm::String)
    .def("print", &fastllm::Data::Print)
    .def("to", static_cast<void (fastllm::Data::*)(void *device)>(&fastllm::Data::ToDevice));

  // New tensor of `dtype` with shape `dims`, built from an all-zero float buffer.
  m.def("zeros", [](const std::vector<int> &dims, fastllm::DataType dtype)->fastllm::Data {
    int nums = 1;
    for (int dim : dims) { nums *= dim; }
    std::vector<float> zero_data(nums, 0);
    return fastllm::Data(dtype, dims, zero_data);
  }, py::arg("dims"), py::arg("dtype"));

  // Concatenates the float contents of `datas`, in order, into one {1, N} FLOAT32 tensor.
  // TODO: honor `dim`. Validate that all non-`dim` dimensions match, concatenate
  // along `dim`, and copy with memcpy instead of element by element.
  m.def("cat", [](const std::vector<fastllm::Data> &datas, int dim)->fastllm::Data {
    size_t total = 0;
    for (const auto &data : datas) {
      total += data.Count(0);
    }
    std::vector<float> vecData;
    vecData.reserve(total);
    // Iterate by reference: the old `for (auto data : datas)` deep-copied every tensor.
    for (const auto &data : datas) {
      const float *values = reinterpret_cast<const float*>(data.cpuData);
      vecData.insert(vecData.end(), values, values + data.Count(0));
    }
    const int seqLen = static_cast<int>(vecData.size());
    return fastllm::Data(fastllm::DataType::FLOAT32, {1, seqLen}, vecData);
  });

  py::class_<fastllm::Tokenizer>(m, "Tokenizer")
    .def("encode", &fastllm::Tokenizer::Encode)
    .def("decode", &fastllm::Tokenizer::Decode, "Decode from Tensor")
    .def("decode", &fastllm::Tokenizer::DecodeTokens, "Decode from Vector")
    // Same as decode but returns raw bytes, without UTF-8 decoding on the Python side.
    .def("decode_byte", [](fastllm::Tokenizer &tokenizer, const fastllm::Data &data){
      std::string ret = tokenizer.Decode(data);
      return py::bytes(ret);
    })
    .def("decode_byte", [](fastllm::Tokenizer &tokenizer, const std::vector<int>& data){
      std::string ret = tokenizer.DecodeTokens(data);
      return py::bytes(ret);
    })
    .def("clear", &fastllm::Tokenizer::Clear)
    .def("insert", &fastllm::Tokenizer::Insert);

  py::class_<fastllm::WeightMap>(m, "WeightMap")
    .def_readonly("tokenizer", &fastllm::WeightMap::tokenizer)
    .def("save_lowbit", &fastllm::WeightMap::SaveLowBitModel)
    .def("set_kv", &fastllm::WeightMap::AddDict)
    .def("set_weight", &fastllm::WeightMap::AddWeight)
    .def("__getitem__", [](fastllm::WeightMap &weight, std::string key){
        return weight[key]; });

  // model classes
  // Each model's `make_input` was previously registered twice with the same
  // signature. The second overload was unreachable and has been dropped.
  py::class_<fastllm::basellm>(m, "basellm");

  py::class_<fastllm::ChatGLMModel, fastllm::basellm>(m, "ChatGLMModel")
    .def(py::init<>())
    .def_readonly("model_type", &fastllm::ChatGLMModel::model_type)
    .def_readonly("weight", &fastllm::ChatGLMModel::weight)
    .def_readonly("block_cnt", &fastllm::ChatGLMModel::block_cnt)
    .def_readonly("bos_token_id", &fastllm::ChatGLMModel::bos_token_id)
    .def_readonly("eos_token_id", &fastllm::ChatGLMModel::eos_token_id)
    .def_readonly("gmask_token_id", &fastllm::ChatGLMModel::gmask_token_id)
    .def("load_weights", &fastllm::ChatGLMModel::LoadFromFile)
    .def("make_input", &fastllm::ChatGLMModel::MakeInput)
    .def("make_history", &fastllm::ChatGLMModel::MakeHistory)
    .def("response", &fastllm::ChatGLMModel::Response)
    .def("batch_response", [](fastllm::ChatGLMModel &model,
                              const std::vector <std::string> &inputs,
                              RuntimeResultBatch retCb,
                              fastllm::GenerationConfig config)->std::vector<std::string> {
      std::vector <std::string> outputs;
      model.ResponseBatch(inputs, outputs, retCb, config);
      return outputs;
    })
    .def("warmup", &fastllm::ChatGLMModel::WarmUp)
    // Returns (token, pastKeyValues) so Python sees the cache updated by Forward.
    .def("forward",
        [](fastllm::ChatGLMModel &model,
           const fastllm::Data &inputIds,
           const fastllm::Data &attentionMask,
           const fastllm::Data &positionIds, std::vector<std::pair<fastllm::Data, fastllm::Data>> &pastKeyValues,
           const fastllm::GenerationConfig &generationConfig, const fastllm::LastTokensManager &tokens) {
          int retV = model.Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
          return std::make_tuple(retV, pastKeyValues);
    })
    .def("launch_response", &fastllm::ChatGLMModel::LaunchResponseTokens)
    .def("fetch_response", &fastllm::ChatGLMModel::FetchResponseTokens)
    .def("save_lowbit_model", &fastllm::ChatGLMModel::SaveLowBitModel);

  py::class_<fastllm::MOSSModel, fastllm::basellm>(m, "MOSSModel")
    .def(py::init<>())
    .def_readonly("model_type", &fastllm::MOSSModel::model_type)
    .def_readonly("weight", &fastllm::MOSSModel::weight)
    .def_readonly("block_cnt", &fastllm::MOSSModel::block_cnt)
    .def_readonly("bos_token_id", &fastllm::MOSSModel::bos_token_id)
    .def_readonly("eos_token_id", &fastllm::MOSSModel::eos_token_id)
    .def("load_weights", &fastllm::MOSSModel::LoadFromFile)
    .def("make_input", &fastllm::MOSSModel::MakeInput)
    .def("make_history", &fastllm::MOSSModel::MakeHistory)
    .def("response", &fastllm::MOSSModel::Response)
    .def("batch_response", [](fastllm::MOSSModel &model,
                              const std::vector <std::string> &inputs,
                              RuntimeResultBatch retCb,
                              fastllm::GenerationConfig config)->std::vector<std::string> {
      std::vector <std::string> outputs;
      model.ResponseBatch(inputs, outputs, retCb, config);
      return outputs;
    })
    .def("forward",
        [](fastllm::MOSSModel &model,
           const fastllm::Data &inputIds,
           const fastllm::Data &attentionMask,
           const fastllm::Data &positionIds, std::vector<std::pair<fastllm::Data, fastllm::Data>> &pastKeyValues,
           const fastllm::GenerationConfig &generationConfig, const fastllm::LastTokensManager &tokens) {
          int retV = model.Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
          return std::make_tuple(retV, pastKeyValues);
    })
    .def("launch_response", &fastllm::MOSSModel::LaunchResponseTokens)
    .def("fetch_response", &fastllm::MOSSModel::FetchResponseTokens)
    .def("save_lowbit_model", &fastllm::MOSSModel::SaveLowBitModel);

  py::class_<fastllm::LlamaModel, fastllm::basellm>(m, "LlamaModel")
    .def(py::init<>())
    .def_readonly("model_type", &fastllm::LlamaModel::model_type)
    .def_readonly("weight", &fastllm::LlamaModel::weight)
    .def_readonly("block_cnt", &fastllm::LlamaModel::block_cnt)
    .def_readonly("bos_token_id", &fastllm::LlamaModel::bos_token_id)
    .def_readonly("eos_token_id", &fastllm::LlamaModel::eos_token_id)
    .def("load_weights", &fastllm::LlamaModel::LoadFromFile)
    .def("make_input", &fastllm::LlamaModel::MakeInput)
    .def("make_history", &fastllm::LlamaModel::MakeHistory)
    .def("response", &fastllm::LlamaModel::Response)
    .def("batch_response", [](fastllm::LlamaModel &model,
                              const std::vector <std::string> &inputs,
                              RuntimeResultBatch retCb,
                              fastllm::GenerationConfig config)->std::vector<std::string> {
      std::vector <std::string> outputs;
      model.ResponseBatch(inputs, outputs, retCb, config);
      return outputs;
    })
    .def("warmup", &fastllm::LlamaModel::WarmUp)
    .def("forward",
        [](fastllm::LlamaModel &model,
           const fastllm::Data &inputIds,
           const fastllm::Data &attentionMask,
           const fastllm::Data &positionIds, std::vector<std::pair<fastllm::Data, fastllm::Data>> &pastKeyValues,
           const fastllm::GenerationConfig &generationConfig, const fastllm::LastTokensManager &tokens) {
          int retV = model.Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
          return std::make_tuple(retV, pastKeyValues);
    })
    .def("launch_response", &fastllm::LlamaModel::LaunchResponseTokens)
    .def("fetch_response", &fastllm::LlamaModel::FetchResponseTokens)
    .def("save_lowbit_model", &fastllm::LlamaModel::SaveLowBitModel);

  py::class_<fastllm::QWenModel, fastllm::basellm>(m, "QWenModel")
    .def(py::init<>())
    .def_readonly("model_type", &fastllm::QWenModel::model_type)
    .def_readonly("weight", &fastllm::QWenModel::weight)
    .def_readonly("block_cnt", &fastllm::QWenModel::block_cnt)
    .def_readonly("bos_token_id", &fastllm::QWenModel::bos_token_id)
    .def_readonly("eos_token_id", &fastllm::QWenModel::eos_token_id)
    .def("load_weights", &fastllm::QWenModel::LoadFromFile)
    .def("make_input", &fastllm::QWenModel::MakeInput)
    .def("make_history", &fastllm::QWenModel::MakeHistory)
    .def("response", &fastllm::QWenModel::Response)
    .def("batch_response", [](fastllm::QWenModel &model,
                              const std::vector <std::string> &inputs,
                              RuntimeResultBatch retCb,
                              fastllm::GenerationConfig config)->std::vector<std::string> {
      std::vector <std::string> outputs;
      model.ResponseBatch(inputs, outputs, retCb, config);
      return outputs;
    })
    .def("warmup", &fastllm::QWenModel::WarmUp)
    .def("forward",
        [](fastllm::QWenModel &model,
           const fastllm::Data &inputIds,
           const fastllm::Data &attentionMask,
           const fastllm::Data &positionIds, std::vector<std::pair<fastllm::Data, fastllm::Data>> &pastKeyValues,
           const fastllm::GenerationConfig &generationConfig, const fastllm::LastTokensManager &tokens) {
          int retV = model.Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
          return std::make_tuple(retV, pastKeyValues);
    })
    .def("launch_response", &fastllm::QWenModel::LaunchResponseTokens)
    .def("fetch_response", &fastllm::QWenModel::FetchResponseTokens)
    .def("save_lowbit_model", &fastllm::QWenModel::SaveLowBitModel);

#ifdef VERSION_INFO
    m.attr("__version__") = VERSION_INFO;
#else
    m.attr("__version__") = "dev";
#endif

}

#endif