c_api.cpp 55.1 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
2

Guolin Ke's avatar
Guolin Ke committed
3
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
4
5
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/utils/log.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
9
10
11
12
13
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
cbecker's avatar
cbecker committed
14
#include <LightGBM/prediction_early_stop.h>
15
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
19
20

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
21
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
22
#include <stdexcept>
wxchan's avatar
wxchan committed
23
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
24
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
28
29
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*!
 * \brief Store the text of a caught std::exception as the last API error.
 * \param ex Exception whose what() string is passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}
/*!
 * \brief Store a string error message as the last API error.
 * \param ex Message text passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// API_BEGIN()/API_END() bracket the body of a C API function in a try block.
// A thrown std::exception or std::string is handed to LGBM_APIHandleException,
// anything else is reported with the text "unknown exception"; in all of those
// cases the handler's return value is returned. Leaving the try block normally
// returns 0.
#define API_BEGIN() try {
#define API_END() \
  } catch (const std::exception& ex) { \
    return LGBM_APIHandleException(ex); \
  } catch (const std::string& ex) { \
    return LGBM_APIHandleException(ex); \
  } catch (...) { \
    return LGBM_APIHandleException("unknown exception"); \
  } \
  return 0;

Guolin Ke's avatar
Guolin Ke committed
46
class Booster {
Nikita Titov's avatar
Nikita Titov committed
47
 public:
Guolin Ke's avatar
Guolin Ke committed
48
  explicit Booster(const char* filename) {
49
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
50
51
  }

Guolin Ke's avatar
Guolin Ke committed
52
  Booster(const Dataset* train_data,
53
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
54
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
55
    config_.Set(param);
56
57
58
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
59
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
60
    if (config_.input_model.size() > 0) {
61
62
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
63
    }
Guolin Ke's avatar
Guolin Ke committed
64

Guolin Ke's avatar
Guolin Ke committed
65
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
66

67
68
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
69
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
70
    if (config_.tree_learner == std::string("feature")) {
71
      Log::Fatal("Do not support feature parallel in c api");
72
    }
Guolin Ke's avatar
Guolin Ke committed
73
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
74
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
75
      config_.tree_learner = "serial";
76
    }
Guolin Ke's avatar
Guolin Ke committed
77
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
78
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
79
80
81
82
83
  }

  /*!
   * \brief Forward to Boosting::MergeFrom with the other booster's boosting object.
   * \param other Booster whose boosting_ is passed on; only this object's mutex_ is held
   */
  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto* source = other->boosting_.get();
    boosting_->MergeFrom(source);
  }

  /*! \brief Nothing to release by hand; the smart-pointer members clean up themselves */
  ~Booster() = default;
88

89
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
90
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
91
92
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
97
98
99
100
101
102
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
103
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
104
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
105
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
111
112
113
114
115
116
117
118
119
120
121
  }

  /*!
   * \brief Switch the booster to another training dataset.
   * \param train_data New training dataset; nothing happens when it equals the
   *        pointer already stored
   */
  void ResetTrainingData(const Dataset* train_data) {
    // Take the lock before looking at train_data_: the member is written under
    // mutex_, so comparing it outside the lock would be an unsynchronized read.
    std::lock_guard<std::mutex> lock(mutex_);
    if (train_data == train_data_) {
      return;
    }
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_,
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
125
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
126
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
127
    if (param.count("num_class")) {
128
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
129
    }
Guolin Ke's avatar
Guolin Ke committed
130
131
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
132
    }
Guolin Ke's avatar
Guolin Ke committed
133
    if (param.count("metric")) {
134
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
135
    }
Guolin Ke's avatar
Guolin Ke committed
136
137

    config_.Set(param);
138
139
140
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
141
142
143

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
144
145
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
153
154
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
155
    }
Guolin Ke's avatar
Guolin Ke committed
156

Guolin Ke's avatar
Guolin Ke committed
157
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
158
159
160
161
162
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
163
164
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
165
166
167
168
169
170
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
171
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
172
  }
Guolin Ke's avatar
Guolin Ke committed
173

174
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
175
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
176
    return boosting_->TrainOneIter(nullptr, nullptr);
177
178
  }

Guolin Ke's avatar
Guolin Ke committed
179
180
181
182
183
184
185
186
187
188
189
  /*!
   * \brief Copy a flat nrow x ncol array of leaf predictions into nested
   *        vectors and pass them to Boosting::RefitTree.
   * \param leaf_preds Flat array read as leaf_preds[i * ncol + j]
   * \param nrow Number of rows in leaf_preds
   * \param ncol Number of columns in leaf_preds
   */
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int32_t i = 0; i < nrow; ++i) {
      // compute the row offset in size_t: i * ncol can exceed the int range
      const size_t row_offset = static_cast<size_t>(i) * static_cast<size_t>(ncol);
      for (int32_t j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[row_offset + static_cast<size_t>(j)];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

190
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
191
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
192
    return boosting_->TrainOneIter(gradients, hessians);
193
194
  }

wxchan's avatar
wxchan committed
195
196
197
198
199
  /*! \brief Call Boosting::RollbackOneIter while holding mutex_ */
  void RollbackOneIter() {
    std::lock_guard<std::mutex> guard(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
200
201
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
202
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
203
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
204
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
205
206
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
207
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
208
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
209
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
210
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
211
      is_raw_score = true;
212
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
213
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
214
215
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
216
    }
Guolin Ke's avatar
Guolin Ke committed
217

Guolin Ke's avatar
Guolin Ke committed
218
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
219
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
220
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
221
    auto pred_fun = predictor.GetPredictFunction();
222
223
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
224
    for (int i = 0; i < nrow; ++i) {
225
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
226
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
227
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
228
      pred_fun(one_row, pred_wrt_ptr);
229
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
230
    }
231
    OMP_THROW_EX();
232
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
233
234
235
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
236
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
237
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
238
239
240
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
241
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
242
243
244
245
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
246
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
247
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
248
249
250
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
251
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
252
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
253
254
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
255
256
  }

Guolin Ke's avatar
Guolin Ke committed
257
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
258
259
260
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

261
262
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
263
  }
264

265
  /*!
   * \brief Load a model from a NUL-terminated string.
   * \param model_str Model text; its length is measured with std::strlen
   */
  void LoadModelFromString(const char* model_str) {
    const size_t model_len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, model_len);
  }

270
271
  /*!
   * \brief Forward to Boosting::SaveModelToString.
   * \param start_iteration Passed through unchanged
   * \param num_iteration Passed through unchanged
   * \return The string produced by the boosting object
   */
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    std::string model_text = boosting_->SaveModelToString(start_iteration, num_iteration);
    return model_text;
  }

274
  std::string DumpModel(int start_iteration, int num_iteration) {
275
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
276
  }
277

278
279
280
281
  /*!
   * \brief Forward to Boosting::FeatureImportance.
   * \param num_iteration Passed through unchanged
   * \param importance_type Passed through unchanged
   * \return The vector produced by the boosting object
   */
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    std::vector<double> importance = boosting_->FeatureImportance(num_iteration, importance_type);
    return importance;
  }

Guolin Ke's avatar
Guolin Ke committed
282
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
283
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
284
285
286
287
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
288
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
289
290
  }

291
  void ShuffleModels(int start_iter, int end_iter) {
292
    std::lock_guard<std::mutex> lock(mutex_);
293
    boosting_->ShuffleModels(start_iter, end_iter);
294
295
  }

wxchan's avatar
wxchan committed
296
297
298
299
300
301
302
  /*!
   * \brief Count the names reported by all training metrics.
   * \return Sum of GetName().size() over train_metric_
   */
  int GetEvalCounts() const {
    int total_names = 0;
    for (size_t m = 0; m < train_metric_.size(); ++m) {
      total_names += static_cast<int>(train_metric_[m]->GetName().size());
    }
    return total_names;
  }
303

wxchan's avatar
wxchan committed
304
305
306
307
  /*!
   * \brief Copy the names of all training metrics into caller-provided buffers.
   * \param out_strs Array of destination buffers, one per name; each name is
   *        copied including its terminating NUL, with no size check here
   * \return Number of names written
   */
  int GetEvalNames(char** out_strs) const {
    int written = 0;
    for (const auto& metric : train_metric_) {
      const auto& names = metric->GetName();
      for (const auto& name : names) {
        char* dest = out_strs[written];
        std::memcpy(dest, name.c_str(), name.size() + 1);
        ++written;
      }
    }
    return written;
  }

wxchan's avatar
wxchan committed
315
316
317
  /*!
   * \brief Copy the feature names of the boosting object into caller-provided buffers.
   * \param out_strs Array of destination buffers, one per name; each name is
   *        copied including its terminating NUL, with no size check here
   * \return Number of names written
   */
  int GetFeatureNames(char** out_strs) const {
    int written = 0;
    const auto& names = boosting_->FeatureNames();
    for (const auto& name : names) {
      char* dest = out_strs[written];
      std::memcpy(dest, name.c_str(), name.size() + 1);
      ++written;
    }
    return written;
  }

wxchan's avatar
wxchan committed
324
  /*! \brief Non-owning, read-only access to the wrapped boosting object */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
325

Nikita Titov's avatar
Nikita Titov committed
326
 private:
wxchan's avatar
wxchan committed
327
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
328
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
329
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
330
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
331
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
332
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
333
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
334
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
335
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
336
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
337
338
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
339
340
};

341
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
342
343
344

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
345
346
347
348
349
350
351
352
353
354
// Helper functions that convert raw input buffers into per-row accessors.

// Returns a callable giving the values of row row_idx as a std::vector<double>.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

// Returns a callable giving row row_idx as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

// Returns a callable giving row idx of a CSR matrix as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
356
357
358

// Row iterator over one column of a CSC matrix
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr,
                  int col_ptr_type,
                  const int32_t* indices,
                  const void* data,
                  int data_type,
                  int64_t ncol_ptr,
                  int64_t nelem,
                  int col_idx);
  ~CSC_RowIterator() = default;

  // Value at row idx; rows may only be requested in ascending order
  double Get(int idx);

  // Next non-zero (index, value) pair; an index < 0 means there is no more data
  std::pair<int, double> NextNonZero();

 private:
  // NOTE(review): roles below are read off the member names; the method
  // bodies are not visible in this part of the file
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
378
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
379
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
380
381
}

Guolin Ke's avatar
Guolin Ke committed
382
int LGBM_DatasetCreateFromFile(const char* filename,
383
384
385
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
386
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
387
388
  auto param = Config::Str2Map(parameters);
  Config config;
389
390
391
392
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
393
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
394
  if (reference == nullptr) {
395
396
397
398
399
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
400
  } else {
401
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
402
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
403
  }
404
  API_END();
Guolin Ke's avatar
Guolin Ke committed
405
406
}

407

Guolin Ke's avatar
Guolin Ke committed
408
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
409
410
411
412
413
414
415
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
416
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
417
418
  auto param = Config::Str2Map(parameters);
  Config config;
419
420
421
422
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
423
  DatasetLoader loader(config, nullptr, 1, nullptr);
424
425
426
427
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
428
429
}

430

Guolin Ke's avatar
Guolin Ke committed
431
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
432
433
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
434
435
436
437
438
439
440
441
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
442
int LGBM_DatasetPushRows(DatasetHandle dataset,
443
444
445
446
447
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
448
449
450
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
451
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
452
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
453
  for (int i = 0; i < nrow; ++i) {
454
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
455
456
457
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
458
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
459
  }
460
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
461
462
463
464
465
466
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
467
/*!
 * \brief Push the rows of a CSR matrix into an existing dataset, starting at start_row.
 * \param dataset Handle reinterpreted as Dataset*
 * \param indptr Forwarded to RowFunctionFromCSR
 * \param indptr_type Forwarded to RowFunctionFromCSR
 * \param indices Forwarded to RowFunctionFromCSR
 * \param data Forwarded to RowFunctionFromCSR
 * \param data_type Forwarded to RowFunctionFromCSR
 * \param nindptr Length of indptr; the row count is nindptr - 1
 * \param nelem Forwarded to RowFunctionFromCSR
 * \param start_row Dataset row index that receives row 0 of the matrix
 * \note The unnamed int64_t parameter is not used.
 * \return 0 on success, -1 when an exception was caught by API_END
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(row);
    const data_size_t dest_row = static_cast<data_size_t>(start_row + row);
    target->PushOneRow(tid, dest_row, one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // finish loading once this batch ends exactly at the dataset's row count
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
498
int LGBM_DatasetCreateFromMat(const void* data,
499
500
501
502
503
504
505
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
527
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
528
529
  auto param = Config::Str2Map(parameters);
  Config config;
530
531
532
533
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
534
  std::unique_ptr<Dataset> ret;
535
536
537
538
539
540
541
542
543
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
544

Guolin Ke's avatar
Guolin Ke committed
545
546
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
547
    Random rand(config.data_random_seed);
548
549
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
550
    sample_cnt = static_cast<int>(sample_indices.size());
551
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
552
    std::vector<std::vector<int>> sample_idx(ncol);
553
554
555

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
556
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
557
      auto idx = sample_indices[i];
558
559
560
561
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
562

563
564
565
566
567
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
568
        }
Guolin Ke's avatar
Guolin Ke committed
569
570
      }
    }
Guolin Ke's avatar
Guolin Ke committed
571
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
572
573
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
574
575
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
576
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
577
  } else {
578
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
579
    ret->CreateValid(
580
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
581
  }
582
583
584
585
586
587
588
589
590
591
592
593
594
595
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
596
597
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
598
  *out = ret.release();
599
  API_END();
600
601
}

Guolin Ke's avatar
Guolin Ke committed
602
int LGBM_DatasetCreateFromCSR(const void* indptr,
603
604
605
606
607
608
609
610
611
612
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
613
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
614
615
  auto param = Config::Str2Map(parameters);
  Config config;
616
617
618
619
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
620
  std::unique_ptr<Dataset> ret;
621
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
622
623
624
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
625
626
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
627
    auto sample_indices = rand.Sample(nrow, sample_cnt);
628
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
629
630
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
631
632
633
634
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
635
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
636
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
637
638
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
639
640
641
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
642
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
643
644
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
645
646
647
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
648
  } else {
649
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
650
    ret->CreateValid(
651
      reinterpret_cast<const Dataset*>(reference));
652
  }
653
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
654
  #pragma omp parallel for schedule(static)
655
  for (int i = 0; i < nindptr - 1; ++i) {
656
    OMP_LOOP_EX_BEGIN();
657
658
659
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
660
    OMP_LOOP_EX_END();
661
  }
662
  OMP_THROW_EX();
663
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
664
  *out = ret.release();
665
  API_END();
666
667
}

Guolin Ke's avatar
Guolin Ke committed
668
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
669
670
671
672
673
674
675
676
677
678
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
679
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
680
681
  auto param = Config::Str2Map(parameters);
  Config config;
682
683
684
685
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
686
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
687
688
689
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
690
691
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
692
    auto sample_indices = rand.Sample(nrow, sample_cnt);
693
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
694
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
695
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
696
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
697
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
698
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
699
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
700
701
702
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
703
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
704
705
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
706
707
        }
      }
708
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
709
    }
710
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
711
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
712
713
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
714
715
716
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
717
  } else {
718
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
719
    ret->CreateValid(
720
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
721
  }
722
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
723
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
724
  for (int i = 0; i < ncol_ptr - 1; ++i) {
725
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
726
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
727
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
728
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
729
730
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
731
732
733
734
735
736
737
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
738
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
739
    }
740
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
741
  }
742
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
743
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
744
  *out = ret.release();
745
  API_END();
Guolin Ke's avatar
Guolin Ke committed
746
747
}

Guolin Ke's avatar
Guolin Ke committed
748
int LGBM_DatasetGetSubset(
749
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
750
751
752
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
753
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
754
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
755
756
  auto param = Config::Str2Map(parameters);
  Config config;
757
758
759
760
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
761
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
762
  CHECK(num_used_row_indices > 0);
763
764
765
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
766
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
767
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
768
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
769
770
771
772
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
773
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
774
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
775
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
776
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
777
778
779
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
780
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
781
782
783
784
785
786
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
787
int LGBM_DatasetGetFeatureNames(
788
789
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
790
  int* num_feature_names) {
791
792
793
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
794
795
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
796
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
797
798
799
800
  }
  API_END();
}

801
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
802
// Destroys the Dataset behind `handle` with `delete`.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
808
int LGBM_DatasetSaveBinary(DatasetHandle handle,
809
                           const char* filename) {
810
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
811
812
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
813
  API_END();
814
815
}

Guolin Ke's avatar
Guolin Ke committed
816
// Sets a named field of the dataset from a raw buffer.
//
// `type` selects how `field_data` is interpreted: C_API_DTYPE_FLOAT32 -> SetFloatField,
// C_API_DTYPE_INT32 -> SetIntField, C_API_DTYPE_FLOAT64 -> SetDoubleField.
// Throws std::runtime_error (converted to an error return by API_END) when `type` is
// none of these or the selected setter reports failure.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool ok = false;
  if (type == C_API_DTYPE_FLOAT32) {
    ok = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    ok = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    ok = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
  }
  if (!ok) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
835
// Looks up a named field of the dataset and exposes it through `out_ptr`/`out_len`.
//
// The float, int and double getters are tried in that order; the first that succeeds
// determines `*out_type` (C_API_DTYPE_FLOAT32 / _INT32 / _FLOAT64).
// Throws std::runtime_error("Field not found") when none succeeds.
// If the resulting pointer is null, `*out_len` is forced to 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // a field that exists but holds no data is reported with zero length
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

858
859
860
861
862
863
864
// Passes the `parameters` string to Dataset::ResetConfig on the given dataset.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
865
int LGBM_DatasetGetNumData(DatasetHandle handle,
866
                           int* out) {
867
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
868
869
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
870
  API_END();
871
872
}

Guolin Ke's avatar
Guolin Ke committed
873
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
874
                              int* out) {
875
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
876
877
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
878
  API_END();
Guolin Ke's avatar
Guolin Ke committed
879
}
880
881
882

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
883
int LGBM_BoosterCreate(const DatasetHandle train_data,
884
885
                       const char* parameters,
                       BoosterHandle* out) {
886
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
887
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
888
889
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
890
  API_END();
891
892
}

Guolin Ke's avatar
Guolin Ke committed
893
int LGBM_BoosterCreateFromModelfile(
894
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
895
  int* out_num_iterations,
896
  BoosterHandle* out) {
897
  API_BEGIN();
wxchan's avatar
wxchan committed
898
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
899
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
900
  *out = ret.release();
901
  API_END();
902
903
}

Guolin Ke's avatar
Guolin Ke committed
904
int LGBM_BoosterLoadModelFromString(
905
906
907
908
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
909
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
910
911
912
913
914
915
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

916
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
917
// Destroys the Booster behind `handle` with `delete`.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

923
// Forwards `start_iter` and `end_iter` to Booster::ShuffleModels.
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
930
// Calls Booster::MergeFrom on `handle`, passing the booster behind `other_handle`.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
939
int LGBM_BoosterAddValidData(BoosterHandle handle,
940
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
941
942
943
944
945
946
947
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
948
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
949
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
950
951
952
953
954
955
956
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
957
// Passes the `parameters` string to Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
964
// Writes the boosting object's NumberOfClasses() to `*out_len`.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
971
972
973
974
975
976
977
// Forwards `leaf_preds`, `nrow` and `ncol` unchanged to Booster::Refit.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
978
// Runs one training iteration via Booster::TrainOneIter().
// `*is_finished` is set to 1 when TrainOneIter returns true and to 0 otherwise.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
989
// Runs one training iteration with caller-supplied gradients and hessians.
//
// When SCORE_T_USE_DOUBLE is defined this is unsupported and Log::Fatal is called.
// Otherwise `*is_finished` is set to 1 when TrainOneIter(grad, hess) returns true
// and to 0 otherwise.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1007
// Calls Booster::RollbackOneIter on the given booster.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1014
// Writes the boosting object's GetCurrentIteration() to `*out_iteration`.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  *out_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1020

1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
// Writes the boosting object's NumModelPerIteration() to `*out_tree_per_iteration`.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  *out_tree_per_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumModelPerIteration();
  API_END();
}

// Writes the boosting object's NumberOfTotalModel() to `*out_models`.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  *out_models = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1035
// Writes Booster::GetEvalCounts() to `*out_len`.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1042
// Hands `out_strs` to Booster::GetEvalNames and stores its return value in `*out_len`.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1049
// Hands `out_strs` to Booster::GetFeatureNames and stores its return value in `*out_len`.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1056
// Writes the boosting object's MaxFeatureIdx() + 1 to `*out_len`.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto max_feature_idx = reinterpret_cast<Booster*>(handle)->GetBoosting()->MaxFeatureIdx();
  *out_len = max_feature_idx + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1063
// Copies the evaluation results returned by GetEvalAt(data_idx) into `out_results`.
// `*out_len` receives the number of results; each value is converted to double.
// `out_results` must have room for that many values (no size is passed in).
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto eval_values = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(eval_values.size());
  size_t pos = 0;
  for (const auto& value : eval_values) {
    out_results[pos] = static_cast<double>(value);
    ++pos;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1078
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1079
1080
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
1084
1085
1086
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1087
int LGBM_BoosterGetPredict(BoosterHandle handle,
1088
1089
1090
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1091
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1092
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1093
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1094
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1095
1096
}

Guolin Ke's avatar
Guolin Ke committed
1097
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1098
1099
1100
1101
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1102
                               const char* parameter,
1103
                               const char* result_filename) {
1104
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1105
1106
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1107
1108
1109
1110
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1111
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1112
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1113
                       config, result_filename);
1114
  API_END();
1115
1116
}

Guolin Ke's avatar
Guolin Ke committed
1117
// Computes how many output values a prediction over `num_row` rows produces:
// `num_row` (widened to int64_t) times NumPredictOneRow for the given iteration count
// and prediction type (leaf-index / contribution flags derived from `predict_type`).
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool is_leaf_index = predict_type == C_API_PREDICT_LEAF_INDEX;
  const bool is_contrib = predict_type == C_API_PREDICT_CONTRIB;
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(num_iteration, is_leaf_index, is_contrib);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1129
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1140
                              const char* parameter,
1141
1142
                              int64_t* out_len,
                              double* out_result) {
1143
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1144
1145
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1146
1147
1148
1149
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1150
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1151
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1152
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1153
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1154
                       config, out_result, out_len);
1155
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1156
}
1157

Guolin Ke's avatar
Guolin Ke committed
1158
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1169
                              const char* parameter,
1170
1171
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1172
1173
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1174
1175
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1186
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1187
1188
1189
1190
1191
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1192
1193
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1194
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1195
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1196
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1197
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1198
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1199
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1200
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1201
1202
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1203
1204
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1205
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1206
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1207
1208
1209
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1210
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1211
1212
1213
1214
1215
1216
1217
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1218
                              const char* parameter,
1219
1220
                              int64_t* out_len,
                              double* out_result) {
1221
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1222
1223
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1224
1225
1226
1227
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1228
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1229
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1230
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1231
                       config, out_result, out_len);
1232
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1233
}
1234

Guolin Ke's avatar
Guolin Ke committed
1235
int LGBM_BoosterSaveModel(BoosterHandle handle,
1236
                          int start_iteration,
1237
1238
                          int num_iteration,
                          const char* filename) {
1239
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1240
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1241
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1242
1243
1244
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1245
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1246
                                  int start_iteration,
1247
                                  int num_iteration,
1248
                                  int64_t buffer_len,
1249
                                  int64_t* out_len,
1250
                                  char* out_str) {
1251
1252
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1253
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1254
  *out_len = static_cast<int64_t>(model.size()) + 1;
1255
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1256
    std::memcpy(out_str, model.c_str(), *out_len);
1257
1258
1259
1260
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1261
int LGBM_BoosterDumpModel(BoosterHandle handle,
1262
                          int start_iteration,
1263
                          int num_iteration,
1264
1265
                          int64_t buffer_len,
                          int64_t* out_len,
1266
                          char* out_str) {
wxchan's avatar
wxchan committed
1267
1268
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1269
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1270
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1271
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1272
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1273
  }
1274
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1275
}
1276

Guolin Ke's avatar
Guolin Ke committed
1277
/*!
 * \brief Read the value of leaf `leaf_idx` in tree `tree_idx` into `*out_val`.
 */
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1287
/*!
 * \brief Overwrite the value of leaf `leaf_idx` in tree `tree_idx` with `val`.
 */
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
/*!
 * \brief Copy the booster's feature importances into `out_results`.
 *
 * One double is written per entry returned by Booster::FeatureImportance, so
 * `out_results` must have room for that many values.
 */
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* cursor = out_results;
  for (const double importance : importances) {
    *cursor = importance;
    ++cursor;
  }
  API_END();
}

1310
1311
1312
1313
1314
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1315
  Config config;
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

/*!
 * \brief Tear down the network layer by calling Network::Dispose.
 */
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1332
1333
1334
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1335
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1336
  if (num_machines > 1) {
1337
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1338
1339
1340
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1341

Guolin Ke's avatar
Guolin Ke committed
1342
// ---- start of helper functions
1343
1344
1345

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1346
  if (data_type == C_API_DTYPE_FLOAT32) {
1347
1348
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1349
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1350
        std::vector<double> ret(num_col);
1351
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1352
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1353
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1354
1355
1356
1357
        }
        return ret;
      };
    } else {
1358
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1359
        std::vector<double> ret(num_col);
1360
        for (int i = 0; i < num_col; ++i) {
1361
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1362
1363
1364
1365
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1366
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1367
1368
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1369
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1370
        std::vector<double> ret(num_col);
1371
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1372
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1373
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1374
1375
1376
1377
        }
        return ret;
      };
    } else {
1378
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1379
        std::vector<double> ret(num_col);
1380
        for (int i = 0; i < num_col; ++i) {
1381
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1382
1383
1384
1385
1386
        }
        return ret;
      };
    }
  }
1387
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1388
1389
1390
1391
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1392
1393
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1394
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1395
1396
1397
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1398
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1399
          ret.emplace_back(i, raw_values[i]);
1400
        }
Guolin Ke's avatar
Guolin Ke committed
1401
1402
1403
      }
      return ret;
    };
1404
  }
Guolin Ke's avatar
Guolin Ke committed
1405
  return nullptr;
1406
1407
1408
}

std::function<std::vector<std::pair<int, double>>(int idx)>
1409
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1410
  if (data_type == C_API_DTYPE_FLOAT32) {
1411
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1412
    if (indptr_type == C_API_DTYPE_INT32) {
1413
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1414
      return [=] (int idx) {
1415
1416
1417
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1418
1419
1420
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1421
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1422
          ret.emplace_back(indices[i], data_ptr[i]);
1423
1424
1425
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1426
    } else if (indptr_type == C_API_DTYPE_INT64) {
1427
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1428
      return [=] (int idx) {
1429
1430
1431
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1432
1433
1434
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1435
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1436
          ret.emplace_back(indices[i], data_ptr[i]);
1437
1438
1439
1440
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1441
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1442
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1443
    if (indptr_type == C_API_DTYPE_INT32) {
1444
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1445
      return [=] (int idx) {
1446
1447
1448
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1449
1450
1451
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1452
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1453
          ret.emplace_back(indices[i], data_ptr[i]);
1454
1455
1456
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1457
    } else if (indptr_type == C_API_DTYPE_INT64) {
1458
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1459
      return [=] (int idx) {
1460
1461
1462
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1463
1464
1465
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1466
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1467
          ret.emplace_back(indices[i], data_ptr[i]);
1468
1469
1470
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1471
1472
    }
  }
1473
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1474
1475
}

Guolin Ke's avatar
Guolin Ke committed
1476
std::function<std::pair<int, double>(int idx)>
1477
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1478
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1479
  if (data_type == C_API_DTYPE_FLOAT32) {
1480
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1481
    if (col_ptr_type == C_API_DTYPE_INT32) {
1482
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1483
1484
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1485
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1486
1487
1488
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1489
        }
Guolin Ke's avatar
Guolin Ke committed
1490
1491
1492
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1493
      };
Guolin Ke's avatar
Guolin Ke committed
1494
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1495
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1496
1497
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1498
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1499
1500
1501
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1502
        }
Guolin Ke's avatar
Guolin Ke committed
1503
1504
1505
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1506
      };
Guolin Ke's avatar
Guolin Ke committed
1507
    }
Guolin Ke's avatar
Guolin Ke committed
1508
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1509
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1510
    if (col_ptr_type == C_API_DTYPE_INT32) {
1511
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1512
1513
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1514
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1515
1516
1517
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1518
        }
Guolin Ke's avatar
Guolin Ke committed
1519
1520
1521
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1522
      };
Guolin Ke's avatar
Guolin Ke committed
1523
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1524
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1525
1526
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1527
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1528
1529
1530
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1531
        }
Guolin Ke's avatar
Guolin Ke committed
1532
1533
1534
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1535
      };
Guolin Ke's avatar
Guolin Ke committed
1536
1537
    }
  }
1538
  throw std::runtime_error("Unknown data type in CSC matrix");
1539
1540
}

Guolin Ke's avatar
Guolin Ke committed
1541
/*!
 * \brief Construct an iterator over column `col_idx` of a CSC matrix.
 *
 * All arguments are forwarded to IterateFunctionFromCSC; the resulting
 * callable is stored and used by Get() and NextNonZero().
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  iter_fun_ = IterateFunctionFromCSC(col_ptr, col_ptr_type, indices,
                                     data, data_type,
                                     ncol_ptr, nelem, col_idx);
}

/*!
 * \brief Return the column's value at row `idx`, or 0 when no entry is stored there.
 *
 * The cursor only moves forward: stored entries are consumed until the current
 * entry's row index reaches or passes `idx`, or the column is exhausted. A
 * query for a row below the current entry therefore yields 0.
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Negative row index is the end-of-column marker.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (cur_idx_ != idx) {
    return 0.0f;
  }
  return cur_val_;
}

/*!
 * \brief Return the next stored (row index, value) entry of the column.
 *
 * Once the underlying callable reports a negative row index the iterator is
 * marked finished; that marker is returned as-is, and every later call
 * returns (-1, 0.0) without touching the callable again.
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}