c_api.cpp 52.9 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
cbecker's avatar
cbecker committed
13
#include <LightGBM/prediction_early_stop.h>
14
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
20
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
21
#include <stdexcept>
wxchan's avatar
wxchan committed
22
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
23
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24

Guolin Ke's avatar
Guolin Ke committed
25
26
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
27
28
29
30
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
31
  explicit Booster(const char* filename) {
32
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
33
34
  }

Guolin Ke's avatar
Guolin Ke committed
35
  Booster(const Dataset* train_data,
36
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
37
    CHECK(train_data->num_features() > 0);
wxchan's avatar
wxchan committed
38
39
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
40
41
42
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
43
44
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
45
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
46
47
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
48

49
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
50

51
52
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
54
55
56
    if (config_.boosting_config.tree_learner_type == std::string("feature")) {
      Log::Fatal("Do not support feature parallel in c api.");
    }
Guolin Ke's avatar
Guolin Ke committed
57
    if (Network::num_machines() == 1 && config_.boosting_config.tree_learner_type != std::string("serial")) {
58
59
60
      Log::Warning("Only find one worker, will switch to serial tree learner.");
      config_.boosting_config.tree_learner_type = "serial";
    }
61
    boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
62
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
63

wxchan's avatar
wxchan committed
64
65
66
67
68
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
69
70
71
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
72

Guolin Ke's avatar
Guolin Ke committed
73
  }
74

75
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
76
77
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
78
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
97
98
99
100
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
Guolin Ke's avatar
Guolin Ke committed
101
      CHECK(train_data->num_features() > 0);
102
103
104
105
106
107
108
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
109
110
111
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
112
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
113
114
115
116
117
118
119
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
120
121
122
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
123
124

    config_.Set(param);
125
126
127
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130
131

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
132
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
137
138
139
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
140
141
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
142
    }
Guolin Ke's avatar
Guolin Ke committed
143

144
    boosting_->ResetConfig(&config_.boosting_config);
Guolin Ke's avatar
Guolin Ke committed
145

wxchan's avatar
wxchan committed
146
147
148
149
150
151
152
153
154
155
156
157
158
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
159
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
160
  }
Guolin Ke's avatar
Guolin Ke committed
161

162
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
163
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
164
    return boosting_->TrainOneIter(nullptr, nullptr);
165
166
167
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
168
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
169
    return boosting_->TrainOneIter(gradients, hessians);
170
171
  }

wxchan's avatar
wxchan committed
172
173
174
175
176
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
177
178
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
179
               const IOConfig& config,
Guolin Ke's avatar
Guolin Ke committed
180
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
182
183
    bool is_predict_leaf = false;
    bool is_raw_score = false;
184
    bool is_predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
185
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
186
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
187
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
188
      is_raw_score = true;
189
190
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      is_predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
191
192
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
193
    }
Guolin Ke's avatar
Guolin Ke committed
194

195
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
196
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Tony-Y's avatar
Tony-Y committed
197
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, is_predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
198
    auto pred_fun = predictor.GetPredictFunction();
199
200
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
201
    for (int i = 0; i < nrow; ++i) {
202
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
203
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
204
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
205
      pred_fun(one_row, pred_wrt_ptr);
206
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
207
    }
208
    OMP_THROW_EX();
Tony-Y's avatar
Tony-Y committed
209
    *out_len = nrow * num_pred_in_one_row;
Guolin Ke's avatar
Guolin Ke committed
210
211
212
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
213
               int data_has_header, const IOConfig& config,
cbecker's avatar
cbecker committed
214
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
215
216
217
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
218
    bool is_predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
219
220
221
222
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
223
224
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      is_predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
225
226
227
    } else {
      is_raw_score = false;
    }
228
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
229
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
230
231
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
232
233
  }

Guolin Ke's avatar
Guolin Ke committed
234
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
235
236
237
238
239
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
240
  }
241

242
  void LoadModelFromString(const char* model_str) {
243
244
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
245
246
247
248
249
250
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

251
252
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
253
  }
254

255
256
257
258
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
259
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
260
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
261
262
263
264
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
265
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
266
267
  }

wxchan's avatar
wxchan committed
268
269
270
271
272
273
274
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
275

wxchan's avatar
wxchan committed
276
277
278
279
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
280
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
281
282
283
284
285
286
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
287
288
289
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
290
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
291
292
293
294
295
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
296
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
297

Guolin Ke's avatar
Guolin Ke committed
298
private:
299

wxchan's avatar
wxchan committed
300
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
301
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
302
303
304
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
305
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
306
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
307
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
308
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
309
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
310
311
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
312
313
314
};

}
Guolin Ke's avatar
Guolin Ke committed
315
316
317

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
318
319
320
321
322
323
324
325
326
327
// Forward declarations of the helpers that adapt raw C-API buffers into
// per-row accessor functions; their definitions live elsewhere in this file.

// Dense matrix -> accessor returning the dense values of row `row_idx`.
auto RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major)
    -> std::function<std::vector<double>(int row_idx)>;

// Dense matrix -> accessor returning row `row_idx` as (index, value) pairs
// (presumably (column, value) -- confirm against the definition).
auto RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major)
    -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

// CSR matrix -> accessor returning row `idx` as (index, value) pairs.
auto RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                        const void* data, int data_type, int64_t nindptr, int64_t nelem)
    -> std::function<std::vector<std::pair<int, double>>(int idx)>;
Guolin Ke's avatar
Guolin Ke committed
329
330
331
332
333

// Walks the rows of a single column of a CSC matrix.
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  // Value stored at row `idx`. Rows must be requested in ascending order.
  double Get(int idx);

  // Next (row, value) pair with a non-zero entry; a negative row index
  // signals that the column has been exhausted.
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_{0};
  int cur_idx_{-1};
  double cur_val_{0.0};
  bool is_end_{false};
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
350
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
351
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
352
353
}

Guolin Ke's avatar
Guolin Ke committed
354
int LGBM_DatasetCreateFromFile(const char* filename,
355
356
357
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
358
  API_BEGIN();
wxchan's avatar
wxchan committed
359
  auto param = ConfigBase::Str2Map(parameters);
360
361
362
363
364
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
365
  DatasetLoader loader(config.io_config,nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
366
  if (reference == nullptr) {
367
368
369
370
371
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
372
  } else {
373
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
374
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
375
  }
376
  API_END();
Guolin Ke's avatar
Guolin Ke committed
377
378
}

379

Guolin Ke's avatar
Guolin Ke committed
380
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
381
382
383
384
385
386
387
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
388
389
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
390
391
392
393
394
395
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
396
397
398
399
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
400
401
}

402

Guolin Ke's avatar
Guolin Ke committed
403
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
404
405
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
406
407
408
409
410
411
412
413
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
414
int LGBM_DatasetPushRows(DatasetHandle dataset,
415
416
417
418
419
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
420
421
422
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
423
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
424
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
425
  for (int i = 0; i < nrow; ++i) {
426
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
427
428
429
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
430
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
431
  }
432
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
437
438
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
439
// Pushes the nindptr - 1 rows of a CSR block into an existing dataset starting
// at row start_row. When the pushed block ends exactly at num_data(), the
// dataset is finalized with FinishLoad(). The unnamed int64_t parameter is not
// used here.
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  // Reject a destination range outside [0, num_data()) before writing anything.
  // NOTE(review): PushOneRow is assumed not to bounds-check row indices. A CHECK
  // failure is assumed to raise via Log::Fatal and be reported by API_END.
  CHECK(start_row >= 0 && nrow >= 0 && start_row + nrow <= static_cast<int64_t>(p_dataset->num_data()));
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid,
                          static_cast<data_size_t>(start_row + i), one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // the last block of rows has been pushed: finalize the dataset
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
470
int LGBM_DatasetCreateFromMat(const void* data,
471
472
473
474
475
476
477
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
478
  API_BEGIN();
wxchan's avatar
wxchan committed
479
  auto param = ConfigBase::Str2Map(parameters);
480
481
482
483
484
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
485
  std::unique_ptr<Dataset> ret;
486
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
487
488
  if (reference == nullptr) {
    // sample data first
489
490
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
491
    auto sample_indices = rand.Sample(nrow, sample_cnt);
492
    sample_cnt = static_cast<int>(sample_indices.size());
493
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
494
    std::vector<std::vector<int>> sample_idx(ncol);
Guolin Ke's avatar
Guolin Ke committed
495
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
496
      auto idx = sample_indices[i];
497
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
498
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
499
        if (std::fabs(row[j]) > kEpsilon || std::isnan(row[j])) {
Guolin Ke's avatar
Guolin Ke committed
500
501
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
502
        }
Guolin Ke's avatar
Guolin Ke committed
503
504
      }
    }
505
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
506
507
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
508
509
510
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
511
  } else {
512
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
513
    ret->CreateValid(
514
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
515
  }
516
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
517
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
518
  for (int i = 0; i < nrow; ++i) {
519
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
520
    const int tid = omp_get_thread_num();
521
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
522
    ret->PushOneRow(tid, i, one_row);
523
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
524
  }
525
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
526
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
527
  *out = ret.release();
528
  API_END();
529
530
}

Guolin Ke's avatar
Guolin Ke committed
531
int LGBM_DatasetCreateFromCSR(const void* indptr,
532
533
534
535
536
537
538
539
540
541
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
542
  API_BEGIN();
wxchan's avatar
wxchan committed
543
  auto param = ConfigBase::Str2Map(parameters);
544
545
546
547
548
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
549
  std::unique_ptr<Dataset> ret;
550
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
551
552
553
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
554
555
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
556
    auto sample_indices = rand.Sample(nrow, sample_cnt);
557
    sample_cnt = static_cast<int>(sample_indices.size());
558
    std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
559
    std::vector<std::vector<int>> sample_idx;
560
561
562
563
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
564
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
565
566
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
567
        }
Guolin Ke's avatar
Guolin Ke committed
568
        if (std::fabs(inner_data.second) > kEpsilon || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
569
570
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
571
572
573
        }
      }
    }
574
    CHECK(num_col >= static_cast<int>(sample_values.size()));
575
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
576
577
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
578
579
580
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
581
  } else {
582
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
583
    ret->CreateValid(
584
      reinterpret_cast<const Dataset*>(reference));
585
  }
586
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
587
  #pragma omp parallel for schedule(static)
588
  for (int i = 0; i < nindptr - 1; ++i) {
589
    OMP_LOOP_EX_BEGIN();
590
591
592
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
593
    OMP_LOOP_EX_END();
594
  }
595
  OMP_THROW_EX();
596
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
597
  *out = ret.release();
598
  API_END();
599
600
}

Guolin Ke's avatar
Guolin Ke committed
601
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
602
603
604
605
606
607
608
609
610
611
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
612
  API_BEGIN();
wxchan's avatar
wxchan committed
613
  auto param = ConfigBase::Str2Map(parameters);
614
615
616
617
618
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
619
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
620
621
622
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
623
624
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
625
    auto sample_indices = rand.Sample(nrow, sample_cnt);
626
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
627
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
628
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
629
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
630
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
631
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
632
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
633
634
635
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
636
        if (std::fabs(val) > kEpsilon || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
637
638
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
639
640
        }
      }
641
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
642
    }
643
    OMP_THROW_EX();
644
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
645
646
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
647
648
649
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
650
  } else {
651
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
652
    ret->CreateValid(
653
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
654
  }
655
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
656
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
657
  for (int i = 0; i < ncol_ptr - 1; ++i) {
658
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
659
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
660
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
661
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
662
663
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
664
665
666
667
668
669
670
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
671
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
672
    }
673
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
674
  }
675
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
676
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
677
  *out = ret.release();
678
  API_END();
Guolin Ke's avatar
Guolin Ke committed
679
680
}

Guolin Ke's avatar
Guolin Ke committed
681
int LGBM_DatasetGetSubset(
682
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
683
684
685
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
686
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
687
688
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
689
690
691
692
693
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
694
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
695
  CHECK(num_used_row_indices > 0);
696
697
698
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
699
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
700
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
701
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
702
703
704
705
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
706
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
707
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
708
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
709
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
710
711
712
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
713
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
714
715
716
717
718
719
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
720
int LGBM_DatasetGetFeatureNames(
721
722
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
723
  int* num_feature_names) {
724
725
726
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
727
728
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
729
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
730
731
732
733
  }
  API_END();
}

734
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
735
int LGBM_DatasetFree(DatasetHandle handle) {
736
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
737
  delete reinterpret_cast<Dataset*>(handle);
738
  API_END();
739
740
}

Guolin Ke's avatar
Guolin Ke committed
741
int LGBM_DatasetSaveBinary(DatasetHandle handle,
742
                           const char* filename) {
743
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
744
745
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
746
  API_END();
747
748
}

Guolin Ke's avatar
Guolin Ke committed
749
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  // Store `num_element` values from `field_data` into the named dataset field.
  // `type` selects how `field_data` is interpreted (C_API_DTYPE_FLOAT32 / INT32 / FLOAT64).
  // Throws std::runtime_error when `type` is not one of those or the matching setter
  // rejects `field_name`.
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
768
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  // Look up a dataset field by name, returning its length, a pointer to its storage and
  // its element type. The float, int and double getters are tried in that order; the
  // first one that succeeds determines *out_type.
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // A null storage pointer is reported as an empty field.
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
791
int LGBM_DatasetGetNumData(DatasetHandle handle,
792
                           int* out) {
793
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
794
795
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
796
  API_END();
797
798
}

Guolin Ke's avatar
Guolin Ke committed
799
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
800
                              int* out) {
801
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
802
803
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
804
  API_END();
Guolin Ke's avatar
Guolin Ke committed
805
}
806
807
808

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
809
int LGBM_BoosterCreate(const DatasetHandle train_data,
810
811
                       const char* parameters,
                       BoosterHandle* out) {
812
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
813
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
814
815
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
816
  API_END();
817
818
}

Guolin Ke's avatar
Guolin Ke committed
819
int LGBM_BoosterCreateFromModelfile(
820
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
821
  int* out_num_iterations,
822
  BoosterHandle* out) {
823
  API_BEGIN();
wxchan's avatar
wxchan committed
824
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
825
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
826
  *out = ret.release();
827
  API_END();
828
829
}

Guolin Ke's avatar
Guolin Ke committed
830
int LGBM_BoosterLoadModelFromString(
831
832
833
834
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
835
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
836
837
838
839
840
841
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

842
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
843
int LGBM_BoosterFree(BoosterHandle handle) {
844
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
845
  delete reinterpret_cast<Booster*>(handle);
846
  API_END();
847
848
}

Guolin Ke's avatar
Guolin Ke committed
849
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  // Merge the model held by `other_handle` into the Booster behind `handle`.
  API_BEGIN();
  auto* target_booster = reinterpret_cast<Booster*>(handle);
  auto* source_booster = reinterpret_cast<Booster*>(other_handle);
  target_booster->MergeFrom(source_booster);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
858
int LGBM_BoosterAddValidData(BoosterHandle handle,
859
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
860
861
862
863
864
865
866
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
867
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
868
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
869
870
871
872
873
874
875
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
876
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  // Re-apply configuration from a parameter string to an existing Booster.
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
883
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  // Report the number of classes of the underlying boosting model.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
890
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  // Run one training iteration with the built-in objective; *is_finished is set to 1
  // when TrainOneIter() returns true and 0 otherwise.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
901
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  // Run one training iteration using caller-supplied gradients and hessians;
  // *is_finished mirrors the boolean returned by TrainOneIter(grad, hess).
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
915
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  // Undo the most recent training iteration.
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
922
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  // Report the boosting model's current iteration count.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
928

Guolin Ke's avatar
Guolin Ke committed
929
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  // Report how many evaluation results the Booster produces (Booster::GetEvalCounts).
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
936
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  // Fill the caller's buffers with evaluation names; *out_len receives the count that
  // Booster::GetEvalNames returns.
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
943
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  // Fill the caller's buffers with the model's feature names; *out_len receives the
  // count that Booster::GetFeatureNames returns.
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
950
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  // The feature count is one past the largest feature index the model knows about.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
957
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  // Copy the evaluation results for dataset `data_idx` into `out_results`, which must
  // have room for every result; *out_len receives the number written.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  const auto eval_results = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(eval_results.size());
  double* dst = out_results;
  for (const auto& metric_value : eval_results) {
    *dst++ = static_cast<double>(metric_value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
972
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
973
974
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
975
976
977
978
979
980
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
981
int LGBM_BoosterGetPredict(BoosterHandle handle,
982
983
984
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
985
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
986
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
987
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
988
  API_END();
Guolin Ke's avatar
Guolin Ke committed
989
990
}

Guolin Ke's avatar
Guolin Ke committed
991
int LGBM_BoosterPredictForFile(BoosterHandle handle,
992
993
994
995
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
996
                               const char* parameter,
997
                               const char* result_filename) {
998
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
999
1000
1001
1002
1003
1004
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1005
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1006
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1007
                       config.io_config, result_filename);
1008
  API_END();
1009
1010
}

Guolin Ke's avatar
Guolin Ke committed
1011
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  // Compute how many output values a prediction over `num_row` rows will produce:
  // num_row times the per-row output count (which depends on leaf-index / contribution mode).
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const bool is_pred_leaf = predict_type == C_API_PREDICT_LEAF_INDEX;
  const bool is_pred_contrib = predict_type == C_API_PREDICT_CONTRIB;
  const int64_t per_row = ref_booster->GetBoosting()->NumPredictOneRow(
    num_iteration, is_pred_leaf, is_pred_contrib);
  // Widen before multiplying: the product can exceed INT_MAX (many rows times many
  // classes/contributions), and signed int overflow is undefined behavior.
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1023
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1034
                              const char* parameter,
1035
1036
                              int64_t* out_len,
                              double* out_result) {
1037
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1038
1039
1040
1041
1042
1043
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1044
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1045
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1046
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1047
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1048
                       config.io_config, out_result, out_len);
1049
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1050
}
1051

Guolin Ke's avatar
Guolin Ke committed
1052
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1063
                              const char* parameter,
1064
1065
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1066
1067
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1080
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
1084
1085
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1086
1087
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1088
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1089
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1090
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1091
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1092
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1093
      if (std::fabs(val) > kEpsilon || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1094
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1095
1096
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1097
1098
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1099
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config.io_config,
cbecker's avatar
cbecker committed
1100
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1101
1102
1103
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1104
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1105
1106
1107
1108
1109
1110
1111
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1112
                              const char* parameter,
1113
1114
                              int64_t* out_len,
                              double* out_result) {
1115
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1116
1117
1118
1119
1120
1121
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1122
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1123
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1124
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1125
                       config.io_config, out_result, out_len);
1126
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1127
}
1128

Guolin Ke's avatar
Guolin Ke committed
1129
int LGBM_BoosterSaveModel(BoosterHandle handle,
1130
1131
                          int num_iteration,
                          const char* filename) {
1132
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1133
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
1134
1135
1136
1137
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1138
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1139
                                  int num_iteration,
1140
1141
                                  int64_t buffer_len, 
                                  int64_t* out_len,
1142
                                  char* out_str) {
1143
1144
1145
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  std::string model = ref_booster->SaveModelToString(num_iteration);
1146
  *out_len = static_cast<int64_t>(model.size()) + 1;
1147
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1148
    std::memcpy(out_str, model.c_str(), *out_len);
1149
1150
1151
1152
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1153
int LGBM_BoosterDumpModel(BoosterHandle handle,
1154
                          int num_iteration,
1155
1156
                          int64_t buffer_len,
                          int64_t* out_len,
1157
                          char* out_str) {
wxchan's avatar
wxchan committed
1158
1159
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1160
  std::string model = ref_booster->DumpModel(num_iteration);
1161
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1162
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1163
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1164
  }
1165
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1166
}
1167

Guolin Ke's avatar
Guolin Ke committed
1168
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  // Read the output value of leaf `leaf_idx` in tree `tree_idx`.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1178
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  // Overwrite the output value of leaf `leaf_idx` in tree `tree_idx`.
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  // Copy per-feature importances into `out_results`, which must hold one value per feature.
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double importance : importances) {
    *dst++ = importance;
  }
  API_END();
}

1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  // Build a NetworkConfig from the arguments and initialize the network layer, but only
  // when more than one machine takes part.
  API_BEGIN();
  NetworkConfig config;
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.time_out = listen_time_out;
  config.num_machines = num_machines;
  const bool is_distributed = num_machines > 1;
  if (is_distributed) {
    Network::Init(config);
  }
  API_END();
}

// Release resources held by the network layer (Network::Dispose).
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

ww's avatar
ww committed
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
int LGBM_GetFuncions(void* AllreduceFuncPtr,
                     void* ReduceScatterFuncPtr,
                     void* AllgatherFuncPtr,
                     int num_machines,
                     int rank) {
  API_BEGIN();
  if(num_machines > 1) {
    auto func1 = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
      auto ptr = *func.target<ReduceFunctionInC>();
      auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr;
      return tmp(arg1, arg2, arg3, arg4, ptr);
    };
    Network::SetAllReduce(func1);
    auto func2 = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
      auto ptr = *func.target<ReduceFunctionInC>();
      auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr;
      return tmp(arg1, arg2, arg3, arg4, arg5, ptr);
    };
    Network::SetReduceScatter(func2);
    Network::SetAllgather((void(*)(char*, int, char*))AllgatherFuncPtr);
    Network::SetNumMachines(num_machines);
    Network::SetRank(rank);
  }
  API_END();

}
Guolin Ke's avatar
Guolin Ke committed
1249
// ---- start of some help functions
1250
1251
1252

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1253
  if (data_type == C_API_DTYPE_FLOAT32) {
1254
1255
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1256
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1257
        std::vector<double> ret(num_col);
1258
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1259
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1260
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1261
1262
1263
1264
        }
        return ret;
      };
    } else {
1265
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1266
        std::vector<double> ret(num_col);
1267
        for (int i = 0; i < num_col; ++i) {
1268
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1269
1270
1271
1272
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1273
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1274
1275
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1276
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1277
        std::vector<double> ret(num_col);
1278
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1279
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1280
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1281
1282
1283
1284
        }
        return ret;
      };
    } else {
1285
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1286
        std::vector<double> ret(num_col);
1287
        for (int i = 0; i < num_col; ++i) {
1288
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1289
1290
1291
1292
1293
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1294
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1295
1296
1297
1298
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1299
1300
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1301
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1302
1303
1304
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1305
        if (std::fabs(raw_values[i]) > kEpsilon || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1306
          ret.emplace_back(i, raw_values[i]);
1307
        }
Guolin Ke's avatar
Guolin Ke committed
1308
1309
1310
      }
      return ret;
    };
1311
  }
Guolin Ke's avatar
Guolin Ke committed
1312
  return nullptr;
1313
1314
1315
1316
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1317
  if (data_type == C_API_DTYPE_FLOAT32) {
1318
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1319
    if (indptr_type == C_API_DTYPE_INT32) {
1320
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1321
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1322
1323
1324
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1325
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1326
          ret.emplace_back(indices[i], data_ptr[i]);
1327
1328
1329
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1330
    } else if (indptr_type == C_API_DTYPE_INT64) {
1331
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1332
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1333
1334
1335
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1336
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1337
          ret.emplace_back(indices[i], data_ptr[i]);
1338
1339
1340
1341
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1342
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1343
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1344
    if (indptr_type == C_API_DTYPE_INT32) {
1345
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1346
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1347
1348
1349
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1350
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1351
          ret.emplace_back(indices[i], data_ptr[i]);
1352
1353
1354
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1355
    } else if (indptr_type == C_API_DTYPE_INT64) {
1356
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1357
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1358
1359
1360
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1361
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1362
          ret.emplace_back(indices[i], data_ptr[i]);
1363
1364
1365
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1366
1367
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1368
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1369
1370
}

Guolin Ke's avatar
Guolin Ke committed
1371
1372
1373
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1374
  if (data_type == C_API_DTYPE_FLOAT32) {
1375
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1376
    if (col_ptr_type == C_API_DTYPE_INT32) {
1377
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1378
1379
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1380
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1381
1382
1383
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1384
        }
Guolin Ke's avatar
Guolin Ke committed
1385
1386
1387
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1388
      };
Guolin Ke's avatar
Guolin Ke committed
1389
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1390
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1391
1392
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1393
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1394
1395
1396
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1397
        }
Guolin Ke's avatar
Guolin Ke committed
1398
1399
1400
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1401
      };
Guolin Ke's avatar
Guolin Ke committed
1402
    }
Guolin Ke's avatar
Guolin Ke committed
1403
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1404
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1405
    if (col_ptr_type == C_API_DTYPE_INT32) {
1406
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1407
1408
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1409
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1410
1411
1412
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1413
        }
Guolin Ke's avatar
Guolin Ke committed
1414
1415
1416
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1417
      };
Guolin Ke's avatar
Guolin Ke committed
1418
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1419
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1420
1421
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1422
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1423
1424
1425
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1426
        }
Guolin Ke's avatar
Guolin Ke committed
1427
1428
1429
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1430
      };
Guolin Ke's avatar
Guolin Ke committed
1431
1432
1433
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1434
1435
}

Guolin Ke's avatar
Guolin Ke committed
1436
// Bind this iterator to column `col_idx` of the given CSC matrix. All
// type dispatch and validation is done by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

// Return the value at row `idx` of this column, or 0.0 if that row holds no
// stored entry.
// The iterator only moves forward: it advances through the stored non-zeros
// until it reaches or passes `idx`. A query for a row behind the current
// position therefore returns 0.0 unless it equals the current row.
double CSC_RowIterator::Get(int idx) {
  for (; !is_end_ && cur_idx_ < idx; ++nonzero_idx_) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Column exhausted; leave nonzero_idx_ where it is.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Return the next stored (row, value) pair of this column and advance.
// Once the column is exhausted, returns (-1, 0.0) and marks the iterator as
// ended; every later call also returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_++);
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}