"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "a3862f151f8fa34154095217389d8661c015f662"
c_api.cpp 27.2 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
#include <omp.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
7
8
9
10
11
12
13
14
15
16
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
17
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
18

Guolin Ke's avatar
Guolin Ke committed
19
20
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
21
22
23
24
25
namespace LightGBM {

class Booster {
public:
  explicit Booster(const char* filename):
26
27
28
    boosting_(Boosting::CreateBoosting(filename)), 
    objective_fun_(nullptr), 
    predictor_(nullptr) {
Guolin Ke's avatar
Guolin Ke committed
29
30
31
32
33
34
  }

  Booster(const Dataset* train_data, 
    std::vector<const Dataset*> valid_data, 
    std::vector<std::string> valid_names,
    const char* parameters)
Guolin Ke's avatar
Guolin Ke committed
35
    :train_data_(train_data), valid_datas_(valid_data), predictor_(nullptr) {
Guolin Ke's avatar
Guolin Ke committed
36
37
38
    config_.LoadFromString(parameters);
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
39
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
40
41
42
43
44
45
46
47
        please use continued train with input score");
    }
    boosting_ = Boosting::CreateBoosting(config_.boosting_type, "");
    // create objective function
    objective_fun_ =
      ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
        config_.objective_config);
    // create training metric
Guolin Ke's avatar
Guolin Ke committed
48
49
50
51
52
53
54
    for (auto metric_type : config_.metric_types) {
      Metric* metric =
        Metric::CreateMetric(metric_type, config_.metric_config);
      if (metric == nullptr) { continue; }
      metric->Init("training", train_data_->metadata(),
        train_data_->num_data());
      train_metric_.push_back(metric);
Guolin Ke's avatar
Guolin Ke committed
55
    }
Guolin Ke's avatar
Guolin Ke committed
56
    
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    // add metric for validation data
    for (size_t i = 0; i < valid_datas_.size(); ++i) {
      valid_metrics_.emplace_back();
      for (auto metric_type : config_.metric_types) {
        Metric* metric = Metric::CreateMetric(metric_type, config_.metric_config);
        if (metric == nullptr) { continue; }
        metric->Init(valid_names[i].c_str(),
          valid_datas_[i]->metadata(),
          valid_datas_[i]->num_data());
        valid_metrics_.back().push_back(metric);
      }
    }
    // initialize the objective function
    objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    // initialize the boosting
    boosting_->Init(config_.boosting_config, train_data_, objective_fun_,
      ConstPtrInVectorWarpper<Metric>(train_metric_));
    // add validation data into boosting
    for (size_t i = 0; i < valid_datas_.size(); ++i) {
      boosting_->AddDataset(valid_datas_[i],
        ConstPtrInVectorWarpper<Metric>(valid_metrics_[i]));
    }
  }

  ~Booster() {
    for (auto& metric : train_metric_) {
      if (metric != nullptr) { delete metric; }
    }
    for (auto& metric : valid_metrics_) {
      for (auto& sub_metric : metric) {
        if (sub_metric != nullptr) { delete sub_metric; }
      }
    }
    valid_metrics_.clear();
    if (boosting_ != nullptr) { delete boosting_; }
    if (objective_fun_ != nullptr) { delete objective_fun_; }
Guolin Ke's avatar
Guolin Ke committed
93
    if (predictor_ != nullptr) { delete predictor_; }
Guolin Ke's avatar
Guolin Ke committed
94
  }
95
96
97
98
99
100
101
102
103
104
105

  bool TrainOneIter() {
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

  void PrepareForPrediction(int num_used_model, int predict_type) {
    boosting_->SetNumUsedModel(num_used_model);
Guolin Ke's avatar
Guolin Ke committed
106
107
108
    if (predictor_ != nullptr) { delete predictor_; }
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
109
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
110
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
111
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
112
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
113
114
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
    }
    predictor_ = new Predictor(boosting_, is_raw_score, is_predict_leaf);
  }

  std::vector<double> Predict(const std::vector<std::pair<int, double>>& features) {
    return predictor_->GetPredictFunction()(features);
121
122
  }

123
124
125
126
  void PredictForFile(const char* data_filename, const char* result_filename, bool data_has_header) {
    predictor_->Predict(data_filename, result_filename, data_has_header);
  }

Guolin Ke's avatar
Guolin Ke committed
127
128
129
  void SaveModelToFile(int num_used_model, const char* filename) {
    boosting_->SaveModelToFile(num_used_model, true, filename);
  }
130
  
131
132
  const Boosting* GetBoosting() const { return boosting_; }

133
134
  const float* GetTrainingScore(int* out_len) const { return boosting_->GetTrainingScore(out_len); }

Guolin Ke's avatar
Guolin Ke committed
135
  const inline int NumberOfClasses() const { return boosting_->NumberOfClasses(); }
136

Guolin Ke's avatar
Guolin Ke committed
137
private:
138

Guolin Ke's avatar
Guolin Ke committed
139
140
141
142
143
144
145
146
147
148
149
150
151
  Boosting* boosting_;
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Training data */
  const Dataset* train_data_;
  /*! \brief Validation data */
  std::vector<const Dataset*> valid_datas_;
  /*! \brief Metric for training data */
  std::vector<Metric*> train_metric_;
  /*! \brief Metrics for validation data */
  std::vector<std::vector<Metric*>> valid_metrics_;
  /*! \brief Training objective function */
  ObjectiveFunction* objective_fun_;
Guolin Ke's avatar
Guolin Ke committed
152
153
  /*! \brief Using predictor for prediction task */
  Predictor* predictor_;
154

Guolin Ke's avatar
Guolin Ke committed
155
156
157
};

}
Guolin Ke's avatar
Guolin Ke committed
158
159
160

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
161
/*!
 * \brief Expose the last recorded error message to C callers.
 * \return c_str() of whatever LastErrorMsg() yields. NOTE(review): the
 *         pointer is only valid as long as LastErrorMsg() keeps the string
 *         alive — confirm it returns a reference to persistent storage.
 */
DllExport const char* LGBM_GetLastError() {
  const auto& last_message = LastErrorMsg();
  return last_message.c_str();
}

/*!
 * \brief Load a Dataset from a text data file.
 * \param reference Optional existing dataset; when non-null, the new data is
 *        loaded aligned with it via LoadFromFileAlignWithOtherDataset.
 * \param out Receives the newly created Dataset handle.
 */
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig cfg;
  cfg.LoadFromString(parameters);
  DatasetLoader file_loader(cfg.io_config, nullptr);
  file_loader.SetHeader(filename);
  if (reference != nullptr) {
    const Dataset* ref_dataset = reinterpret_cast<const Dataset*>(*reference);
    *out = file_loader.LoadFromFileAlignWithOtherDataset(filename, ref_dataset);
  } else {
    *out = file_loader.LoadFromFile(filename);
  }
  API_END();
}

/*!
 * \brief Load a Dataset previously written in binary form.
 * A default-constructed OverallConfig is used (no parameter string is
 * accepted); the trailing arguments 0 and 1 are passed to LoadFromBinFile
 * unchanged.
 */
DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig default_cfg;
  DatasetLoader bin_loader(default_cfg.io_config, nullptr);
  *out = bin_loader.LoadFromBinFile(filename, 0, 1);
  API_END();
}

/*!
 * \brief Build a Dataset from a dense matrix.
 *
 * Without a reference, up to bin_construct_sample_cnt rows are sampled to
 * construct the feature bins; with a reference, its feature mappers are
 * copied. Every row is then pushed in parallel and the dataset finalized.
 * \param data Matrix buffer interpreted according to data_type/is_row_major.
 * \param out Receives the new Dataset handle (ownership passes to caller).
 */
DllExport int LGBM_CreateDatasetFromMat(const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  // Own the dataset until it is fully built so an exception thrown after
  // allocation (e.g. by CopyFeatureMapperFrom or FinishLoad) cannot leak it.
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  if (reference == nullptr) {
    // sample data first
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // one vector of sampled values per column
    std::vector<std::vector<double>> sample_values(ncol);
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto row = get_row_fun(static_cast<int>(sample_indices[i]));
      for (size_t j = 0; j < row.size(); ++j) {
        // values with |v| <= 1e-15 are treated as zero and not sampled
        if (std::fabs(row[j]) > 1e-15) {
          sample_values[j].push_back(row[j]);
        }
      }
    }
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, config.io_config.num_class));
    ret->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference), config.io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  // hand ownership to the caller only once construction has succeeded
  *out = ret.release();
  API_END();
}

237
238
/*!
 * \brief Build a Dataset from a CSR sparse matrix.
 *
 * Without a reference, up to bin_construct_sample_cnt rows are sampled to
 * construct the feature bins; with a reference, its feature mappers are
 * copied. Every row is then pushed in parallel and the dataset finalized.
 * \param nindptr Length of indptr (number of rows + 1).
 * \param num_col Declared column count; must be >= the highest sampled
 *        column index + 1 (enforced with CHECK).
 * \param out Receives the new Dataset handle (ownership passes to caller).
 */
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t num_col,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  // Own the dataset until it is fully built so an exception thrown after
  // allocation (e.g. by CopyFeatureMapperFrom or FinishLoad) cannot leak it.
  std::unique_ptr<Dataset> ret;
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // indptr holds one more entry than there are rows
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // One vector of sampled values per column, grown on demand because a
    // CSR row only reveals the column indices it actually contains.
    std::vector<std::vector<double>> sample_values;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto row = get_row_fun(static_cast<int>(sample_indices[i]));
      for (const auto& inner_data : row) {
        // values with |v| <= 1e-15 are treated as zero and not sampled
        if (std::fabs(inner_data.second) > 1e-15) {
          const size_t feature_idx = static_cast<size_t>(inner_data.first);
          if (feature_idx >= sample_values.size()) {
            sample_values.resize(feature_idx + 1);
          }
          sample_values[feature_idx].push_back(inner_data.second);
        }
      }
    }
    // compare in int64_t so the sampled column count is never narrowed
    CHECK(num_col >= static_cast<int64_t>(sample_values.size()));
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, config.io_config.num_class));
    ret->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference), config.io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
  }
  ret->FinishLoad();
  // hand ownership to the caller only once construction has succeeded
  *out = ret.release();
  API_END();
}

297
298
/*!
 * \brief Build a Dataset from a CSC sparse matrix.
 *
 * Without a reference, the sampled row indices are used to extract values
 * from every column (in parallel) to construct the feature bins; with a
 * reference, its feature mappers are copied. Every column is then pushed in
 * parallel and the dataset finalized.
 * \param ncol_ptr Length of col_ptr (number of columns + 1).
 * \param num_row Number of rows in the matrix.
 * \param out Receives the new Dataset handle (ownership passes to caller).
 */
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
  int col_ptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t ncol_ptr,
  int64_t nelem,
  int64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  API_BEGIN();
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  // Own the dataset until it is fully built so an exception thrown after
  // allocation (e.g. by CopyFeatureMapperFrom or FinishLoad) cannot leak it.
  std::unique_ptr<Dataset> ret;
  auto get_col_fun = ColumnFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem);
  const int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    // one vector of sampled values per column; each column is filled by
    // exactly one iteration, so the parallel writes do not overlap
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
#pragma omp parallel for schedule(guided)
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
      auto cur_col = get_col_fun(i);
      sample_values[i] = SampleFromOneColumn(cur_col, sample_indices);
    }
    ret.reset(loader.CostructFromSampleData(sample_values, sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow, config.io_config.num_class));
    ret->CopyFeatureMapperFrom(reinterpret_cast<const Dataset*>(*reference), config.io_config.is_enable_sparse);
  }

#pragma omp parallel for schedule(guided)
  for (int i = 0; i < ncol_ptr - 1; ++i) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(i);
    ret->PushOneColumn(tid, i, one_col);
  }
  ret->FinishLoad();
  // hand ownership to the caller only once construction has succeeded
  *out = ret.release();
  API_END();
}

344
/*! \brief Destroy a Dataset previously returned through a DatesetHandle. */
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
  API_BEGIN();
  delete reinterpret_cast<Dataset*>(handle);
  API_END();
}

/*! \brief Write the dataset behind handle to filename via Dataset::SaveBinaryFile. */
DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
  const char* filename) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->SaveBinaryFile(filename);
  API_END();
}

/*!
 * \brief Set a named field (e.g. labels/weights) on a dataset.
 * \param type C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; other values fail.
 * \param num_element Element count; narrowed to int32_t before forwarding.
 * Throws std::runtime_error (reported through API_END) when the type is
 * unsupported or the Dataset rejects the field.
 */
DllExport int LGBM_DatasetSetField(DatesetHandle handle,
  const char* field_name,
  const void* field_data,
  int64_t num_element,
  int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  }
  // either `type` is unsupported or the Dataset rejected the field name
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

/*!
 * \brief Look up a named field, trying the float fields first, then int.
 * On success *out_ptr/*out_len describe the field data and *out_type is
 * C_API_DTYPE_FLOAT32 or C_API_DTYPE_INT32; otherwise std::runtime_error
 * is thrown (reported through API_END).
 */
DllExport int LGBM_DatasetGetField(DatesetHandle handle,
  const char* field_name,
  int64_t* out_len,
  const void** out_ptr,
  int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else {
    throw std::runtime_error("Field not found");
  }
  API_END();
}

/*! \brief Report the dataset's num_data() through *out. */
DllExport int LGBM_DatasetGetNumData(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  Dataset* ds = reinterpret_cast<Dataset*>(handle);
  *out = ds->num_data();
  API_END();
}

/*! \brief Report the dataset's num_total_features() through *out. */
DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
  int64_t* out) {
  API_BEGIN();
  Dataset* ds = reinterpret_cast<Dataset*>(handle);
  *out = ds->num_total_features();
  API_END();
}
410
411
412
413
414
415
416
417
418
419


// ---- start of booster

/*!
 * \brief Create a trainable Booster over train_data and n_valid_datas
 *        validation sets.
 * \param valid_datas Array of n_valid_datas dataset handles.
 * \param valid_names Array of n_valid_datas names, parallel to valid_datas.
 * \param out Receives the new Booster handle.
 */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
  const DatesetHandle valid_datas[],
  const char* valid_names[],
  int n_valid_datas,
  const char* parameters,
  BoosterHandle* out) {
  API_BEGIN();
  const Dataset* train_set = reinterpret_cast<const Dataset*>(train_data);
  std::vector<const Dataset*> valid_sets;
  std::vector<std::string> valid_set_names;
  for (int idx = 0; idx < n_valid_datas; ++idx) {
    valid_sets.push_back(reinterpret_cast<const Dataset*>(valid_datas[idx]));
    valid_set_names.push_back(std::string(valid_names[idx]));
  }
  *out = new Booster(train_set, valid_sets, valid_set_names, parameters);
  API_END();
}

/*! \brief Create a Booster from a saved model file; handle returned via out. */
DllExport int LGBM_BoosterLoadFromModelfile(
  const char* filename,
  BoosterHandle* out) {
  API_BEGIN();
  Booster* loaded = new Booster(filename);
  *out = loaded;
  API_END();
}

/*! \brief Destroy a Booster previously returned through a BoosterHandle. */
DllExport int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  delete reinterpret_cast<Booster*>(handle);
  API_END();
}

/*!
 * \brief Run one training iteration with the configured objective.
 * \param is_finished Set to 1 when Booster::TrainOneIter returns true, else 0.
 */
DllExport int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
 * \brief Run one training iteration with caller-supplied gradients/hessians.
 * \param is_finished Set to 1 when Booster::TrainOneIter returns true, else 0.
 */
DllExport int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
  const float* grad,
  const float* hess,
  int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

/*!
 * \brief Copy the evaluation results for dataset index `data` into
 *        out_results (converted to float); their count goes to *out_len.
 * The caller must provide an out_results buffer large enough for them.
 */
DllExport int LGBM_BoosterEval(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_results) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = model->GetEvalAt(data);
  const size_t n_values = eval_values.size();
  *out_len = static_cast<int64_t>(n_values);
  for (size_t k = 0; k < n_values; ++k) {
    out_results[k] = static_cast<float>(eval_values[k]);
  }
  API_END();
}

/*!
 * \brief Expose the booster's training scores without copying.
 * *out_result points at the buffer returned by GetTrainingScore; its length
 * is written to *out_len.
 */
DllExport int LGBM_BoosterGetScore(BoosterHandle handle,
  int64_t* out_len,
  const float** out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  int score_count = 0;
  const float* scores = booster->GetTrainingScore(&score_count);
  *out_result = scores;
  *out_len = static_cast<int64_t>(score_count);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
498
499
/*!
 * \brief Fill out_result with the boosting's predictions for dataset index
 *        `data` (via Boosting::GetPredictAt); their count goes to *out_len.
 */
DllExport int LGBM_BoosterGetPredict(BoosterHandle handle,
  int data,
  int64_t* out_len,
  float* out_result) {
  API_BEGIN();
  const Boosting* model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  int n_predictions = 0;
  model->GetPredictAt(data, out_result, &n_predictions);
  *out_len = static_cast<int64_t>(n_predictions);
  API_END();
}

511
512
513
514
515
516
/*!
 * \brief Predict every row of data_filename and write to result_filename.
 * \param data_has_header Treated as true when > 0.
 */
DllExport int LGBM_BoosterPredictForFile(BoosterHandle handle,
  int predict_type,
  int64_t n_used_trees,
  int data_has_header,
  const char* data_filename,
  const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);
  const bool has_header = (data_has_header > 0);
  booster->PredictForFile(data_filename, result_filename, has_header);
  API_END();
}

525
/*!
 * \brief Predict every row of a CSR matrix.
 *
 * Rows are predicted in parallel; for row i, the first NumberOfClasses()
 * values of the prediction are written to
 * out_result[i * num_class .. i * num_class + num_class - 1].
 * NOTE(review): only num_class values per row are copied regardless of
 * predict_type — confirm this is intended for leaf-index prediction.
 * \param nindptr Length of indptr (number of rows + 1).
 * \param out_result Caller buffer of at least (nindptr - 1) * num_class doubles.
 */
DllExport int LGBM_BoosterPredictForCSR(BoosterHandle handle,
  const void* indptr,
  int indptr_type,
  const int32_t* indices,
  const void* data,
  int data_type,
  int64_t nindptr,
  int64_t nelem,
  int64_t,
  int predict_type,
  int64_t n_used_trees,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);

  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int num_class = ref_booster->NumberOfClasses();
  const int nrow = static_cast<int>(nindptr - 1);
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // Offset computed in size_t: i * num_class can overflow int for large
    // inputs and would then write outside the intended row.
    double* row_out = out_result + static_cast<size_t>(i) * static_cast<size_t>(num_class);
    for (int j = 0; j < num_class; ++j) {
      row_out[j] = prediction_result[j];
    }
  }
  API_END();
}
554
555
556

/*!
 * \brief Predict every row of a dense matrix.
 *
 * Rows are predicted in parallel; for row i, the first NumberOfClasses()
 * values of the prediction are written to
 * out_result[i * num_class .. i * num_class + num_class - 1].
 * NOTE(review): only num_class values per row are copied regardless of
 * predict_type — confirm this is intended for leaf-index prediction.
 * \param out_result Caller buffer of at least nrow * num_class doubles.
 */
DllExport int LGBM_BoosterPredictForMat(BoosterHandle handle,
  const void* data,
  int data_type,
  int32_t nrow,
  int32_t ncol,
  int is_row_major,
  int predict_type,
  int64_t n_used_trees,
  double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->PrepareForPrediction(static_cast<int>(n_used_trees), predict_type);

  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
  const int num_class = ref_booster->NumberOfClasses();
#pragma omp parallel for schedule(guided)
  for (int i = 0; i < nrow; ++i) {
    auto one_row = get_row_fun(i);
    auto prediction_result = ref_booster->Predict(one_row);
    // Offset computed in size_t: i * num_class can overflow int for large
    // inputs and would then write outside the intended row.
    double* row_out = out_result + static_cast<size_t>(i) * static_cast<size_t>(num_class);
    for (int j = 0; j < num_class; ++j) {
      row_out[j] = prediction_result[j];
    }
  }
  API_END();
}
580
581
582

/*! \brief Save the first num_used_model models of the booster to filename. */
DllExport int LGBM_BoosterSaveModel(BoosterHandle handle,
  int num_used_model,
  const char* filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SaveModelToFile(num_used_model, filename);
  API_END();
}
589

Guolin Ke's avatar
Guolin Ke committed
590
// ---- start of helper functions
591
592
593

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
594
  if (data_type == C_API_DTYPE_FLOAT32) {
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<double> ret;
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(static_cast<double>(*(tmp_ptr + i)));
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<double> ret;
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(static_cast<double>(*(data_ptr + num_row * i + row_idx)));
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
616
  } else if (data_type == C_API_DTYPE_FLOAT64) {
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
      return [data_ptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<double> ret;
        auto tmp_ptr = data_ptr + num_col * row_idx;
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(static_cast<double>(*(tmp_ptr + i)));
        }
        return ret;
      };
    } else {
      return [data_ptr, num_col, num_row](int row_idx) {
        CHECK(row_idx < num_row);
        std::vector<double> ret;
        for (int i = 0; i < num_col; ++i) {
          ret.push_back(static_cast<double>(*(data_ptr + num_row * i + row_idx)));
        }
        return ret;
      };
    }
  } else {
    Log::Fatal("unknown data type in RowFunctionFromDenseMatric");
  }
Guolin Ke's avatar
Guolin Ke committed
641
  return nullptr;
642
643
644
645
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
646
647
648
649
650
651
652
653
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
    return [inner_function](int row_idx) {
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
654
        }
Guolin Ke's avatar
Guolin Ke committed
655
656
657
      }
      return ret;
    };
658
  }
Guolin Ke's avatar
Guolin Ke committed
659
  return nullptr;
660
661
662
663
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
664
  if (data_type == C_API_DTYPE_FLOAT32) {
665
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
666
    if (indptr_type == C_API_DTYPE_INT32) {
667
668
669
670
671
672
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        CHECK(idx + 1 < nindptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
673
        CHECK(start >= 0 && end <= nelem);
674
675
676
677
678
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
679
    } else if (indptr_type == C_API_DTYPE_INT64) {
680
681
682
683
684
685
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        CHECK(idx + 1 < nindptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
686
        CHECK(start >= 0 && end <= nelem);
687
688
689
690
691
692
693
694
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    } else {
      Log::Fatal("unknown data type in RowFunctionFromCSR");
    }
Guolin Ke's avatar
Guolin Ke committed
695
  } else if (data_type == C_API_DTYPE_FLOAT64) {
696
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
697
    if (indptr_type == C_API_DTYPE_INT32) {
698
699
700
701
702
703
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        CHECK(idx + 1 < nindptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
704
        CHECK(start >= 0 && end <= nelem);
705
706
707
708
709
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
710
    } else if (indptr_type == C_API_DTYPE_INT64) {
711
712
713
714
715
716
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
        CHECK(idx + 1 < nindptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
717
        CHECK(start >= 0 && end <= nelem);
718
719
720
721
722
723
724
725
726
727
728
        for (int64_t i = start; i <= end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    } else {
      Log::Fatal("unknown data type in RowFunctionFromCSR");
    }
  } else {
    Log::Fatal("unknown data type in RowFunctionFromCSR");
  }
Guolin Ke's avatar
Guolin Ke committed
729
  return nullptr;
730
731
732
733
}

std::function<std::vector<std::pair<int, double>>(int idx)>
ColumnFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
734
  if (data_type == C_API_DTYPE_FLOAT32) {
735
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
736
    if (col_ptr_type == C_API_DTYPE_INT32) {
737
738
739
740
741
742
743
744
745
746
747
748
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        CHECK(idx + 1 < ncol_ptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        CHECK(start >= 0 && end <= nelem);
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
749
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        CHECK(idx + 1 < ncol_ptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        CHECK(start >= 0 && end <= nelem);
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    } else {
      Log::Fatal("unknown data type in ColumnFunctionFromCSC");
    }
Guolin Ke's avatar
Guolin Ke committed
765
  } else if (data_type == C_API_DTYPE_FLOAT64) {
766
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
767
    if (col_ptr_type == C_API_DTYPE_INT32) {
768
769
770
771
772
773
774
775
776
777
778
779
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        CHECK(idx + 1 < ncol_ptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        CHECK(start >= 0 && end <= nelem);
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
780
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem](int idx) {
        CHECK(idx + 1 < ncol_ptr);
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_col_ptr[idx];
        int64_t end = ptr_col_ptr[idx + 1];
        CHECK(start >= 0 && end <= nelem);
        for (int64_t i = start; i < end; ++i) {
          ret.emplace_back(indices[i], data_ptr[i]);
        }
        return ret;
      };
    } else {
      Log::Fatal("unknown data type in ColumnFunctionFromCSC");
    }
  } else {
    Log::Fatal("unknown data type in ColumnFunctionFromCSC");
  }
Guolin Ke's avatar
Guolin Ke committed
799
  return nullptr;
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
}

// Densifies the requested rows of one sparse column.
// `data` holds (row index, value) pairs and `indices` the rows to sample. Both
// are expected in ascending order, because a single cursor only moves forward.
// Rows absent from `data` yield 0; the result has one entry per element of `indices`.
std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices) {
  std::vector<double> sampled;
  const size_t num_entries = data.size();
  size_t cursor = 0;
  for (size_t k = 0; k < indices.size(); ++k) {
    const int target = static_cast<int>(indices[k]);
    // Skip stored entries that precede the requested row.
    for (; cursor < num_entries && data[cursor].first < target; ++cursor) {
    }
    const bool hit = cursor < num_entries && data[cursor].first == target;
    sampled.push_back(hit ? data[cursor].second : 0);
  }
  return sampled;
}