c_api.cpp 52.3 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
cbecker's avatar
cbecker committed
13
#include <LightGBM/prediction_early_stop.h>
14
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
20
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
21
#include <stdexcept>
wxchan's avatar
wxchan committed
22
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
23
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24

Guolin Ke's avatar
Guolin Ke committed
25
26
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
27
28
29
30
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
31
  explicit Booster(const char* filename) {
32
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
33
34
  }

Guolin Ke's avatar
Guolin Ke committed
35
  Booster(const Dataset* train_data,
36
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
37
    CHECK(train_data->num_features() > 0);
wxchan's avatar
wxchan committed
38
39
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
40
41
42
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
43
44
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
45
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
46
47
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
48

49
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
50

51
52
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
54
55
56
    if (config_.boosting_config.tree_learner_type == std::string("feature")) {
      Log::Fatal("Do not support feature parallel in c api.");
    }
Guolin Ke's avatar
Guolin Ke committed
57
    if (Network::num_machines() == 1 && config_.boosting_config.tree_learner_type != std::string("serial")) {
58
59
60
      Log::Warning("Only find one worker, will switch to serial tree learner.");
      config_.boosting_config.tree_learner_type = "serial";
    }
61
    boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
62
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
63

wxchan's avatar
wxchan committed
64
65
66
67
68
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
69
70
71
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
72

Guolin Ke's avatar
Guolin Ke committed
73
  }
74

75
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
76
77
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
78
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
97
98
99
100
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
Guolin Ke's avatar
Guolin Ke committed
101
      CHECK(train_data->num_features() > 0);
102
103
104
105
106
107
108
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
109
110
111
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
112
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
113
114
115
116
117
118
119
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
120
121
122
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
123
124

    config_.Set(param);
125
126
127
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
128
129
130
131

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
132
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
137
138
139
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
140
141
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
142
    }
Guolin Ke's avatar
Guolin Ke committed
143

144
    boosting_->ResetConfig(&config_.boosting_config);
Guolin Ke's avatar
Guolin Ke committed
145

wxchan's avatar
wxchan committed
146
147
148
149
150
151
152
153
154
155
156
157
158
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
159
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
160
  }
Guolin Ke's avatar
Guolin Ke committed
161

162
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
163
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
164
    return boosting_->TrainOneIter(nullptr, nullptr);
165
166
167
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
168
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
169
    return boosting_->TrainOneIter(gradients, hessians);
170
171
  }

wxchan's avatar
wxchan committed
172
173
174
175
176
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
177
178
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
179
               const IOConfig& config,
Guolin Ke's avatar
Guolin Ke committed
180
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
181
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
182
183
    bool is_predict_leaf = false;
    bool is_raw_score = false;
184
    bool is_predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
185
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
186
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
187
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
188
      is_raw_score = true;
189
190
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      is_predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
191
192
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
193
    }
Guolin Ke's avatar
Guolin Ke committed
194

195
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
196
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Tony-Y's avatar
Tony-Y committed
197
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, is_predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
198
    auto pred_fun = predictor.GetPredictFunction();
199
200
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
201
    for (int i = 0; i < nrow; ++i) {
202
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
203
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
204
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
205
      pred_fun(one_row, pred_wrt_ptr);
206
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
207
    }
208
    OMP_THROW_EX();
Tony-Y's avatar
Tony-Y committed
209
    *out_len = nrow * num_pred_in_one_row;
Guolin Ke's avatar
Guolin Ke committed
210
211
212
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
213
               int data_has_header, const IOConfig& config,
cbecker's avatar
cbecker committed
214
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
215
216
217
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
218
    bool is_predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
219
220
221
222
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
223
224
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      is_predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
225
226
227
    } else {
      is_raw_score = false;
    }
228
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
229
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
230
231
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
232
233
  }

Guolin Ke's avatar
Guolin Ke committed
234
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
235
236
237
238
239
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
240
  }
241

242
  void LoadModelFromString(const char* model_str) {
243
244
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
245
246
247
248
249
250
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

251
252
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
253
  }
254

255
256
257
258
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
259
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
260
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
261
262
263
264
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
265
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
266
267
  }

wxchan's avatar
wxchan committed
268
269
270
271
272
273
274
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
275

wxchan's avatar
wxchan committed
276
277
278
279
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
280
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
281
282
283
284
285
286
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
287
288
289
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
290
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
291
292
293
294
295
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
296
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
297

Guolin Ke's avatar
Guolin Ke committed
298
private:
299

wxchan's avatar
wxchan committed
300
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
301
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
302
303
304
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
305
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
306
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
307
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
308
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
309
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
310
311
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
312
313
314
};

}
Guolin Ke's avatar
Guolin Ke committed
315
316
317

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
318
319
320
321
322
323
324
325
326
327
// some helper functions used to convert data

/*!
 * \brief Forward declaration: builds an accessor that returns one row of a
 *        dense matrix as a vector of doubles, given the row index.
 *        (The definition is not in this part of the file.)
 */
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

/*!
 * \brief Forward declaration: same inputs as RowFunctionFromDenseMatric, but
 *        the accessor returns a row as (int, double) pairs
 *        — presumably (column index, value); confirm against the definition.
 */
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

/*!
 * \brief Forward declaration: builds an accessor over CSR arrays
 *        (indptr / indices / data) that returns one row as (int, double) pairs.
 */
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
329
330
331
332
333

// Row iterator of one column for CSC matrix
class CSC_RowIterator {
public:
  /*!
   * \brief Bind the iterator to column col_idx of a CSC matrix described by
   *        col_ptr / indices / data. (Defined outside this class body.)
   */
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // return value at idx, only can access by ascent order
  double Get(int idx);
  // return next non-zero pair, if index < 0, means no more data
  std::pair<int, double> NextNonZero();
private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  // double literal: the member is a double, not a float
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
350
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
351
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
352
353
}

Guolin Ke's avatar
Guolin Ke committed
354
int LGBM_DatasetCreateFromFile(const char* filename,
355
356
357
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
358
  API_BEGIN();
wxchan's avatar
wxchan committed
359
  auto param = ConfigBase::Str2Map(parameters);
360
361
362
363
364
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
365
  DatasetLoader loader(config.io_config,nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
366
  if (reference == nullptr) {
367
368
369
370
371
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
372
  } else {
373
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
374
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
375
  }
376
  API_END();
Guolin Ke's avatar
Guolin Ke committed
377
378
}

379

Guolin Ke's avatar
Guolin Ke committed
380
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
381
382
383
384
385
386
387
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
388
389
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
390
391
392
393
394
395
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
396
397
398
399
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
400
401
}

402

Guolin Ke's avatar
Guolin Ke committed
403
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
404
405
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
406
407
408
409
410
411
412
413
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
414
int LGBM_DatasetPushRows(DatasetHandle dataset,
415
416
417
418
419
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
420
421
422
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
423
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
424
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
425
  for (int i = 0; i < nrow; ++i) {
426
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
427
428
429
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
430
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
431
  }
432
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
437
438
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
439
/*!
 * \brief Push a CSR block of rows into an existing dataset.
 * \param dataset Target dataset handle
 * \param indptr CSR row pointer array
 * \param indptr_type Type code of indptr
 * \param indices CSR column indices
 * \param data CSR values
 * \param data_type Type code of data
 * \param nindptr Length of indptr; the block holds nindptr - 1 rows
 * \param nelem Number of stored elements
 * \param start_row Index in the dataset of the block's first row
 * \note The unnamed int64_t parameter is not used by this function.
 * \return Value produced by the API_BEGIN / API_END wrapper
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid,
                          static_cast<data_size_t>(start_row + i), one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // the block that reaches the dataset's last row triggers FinishLoad
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
470
int LGBM_DatasetCreateFromMat(const void* data,
471
472
473
474
475
476
477
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
478
  API_BEGIN();
wxchan's avatar
wxchan committed
479
  auto param = ConfigBase::Str2Map(parameters);
480
481
482
483
484
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
485
  std::unique_ptr<Dataset> ret;
486
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
487
488
  if (reference == nullptr) {
    // sample data first
489
490
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
491
    auto sample_indices = rand.Sample(nrow, sample_cnt);
492
    sample_cnt = static_cast<int>(sample_indices.size());
493
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
494
    std::vector<std::vector<int>> sample_idx(ncol);
Guolin Ke's avatar
Guolin Ke committed
495
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
496
      auto idx = sample_indices[i];
497
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
498
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
499
        if (std::fabs(row[j]) > kZeroThreshold || std::isnan(row[j])) {
Guolin Ke's avatar
Guolin Ke committed
500
501
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
502
        }
Guolin Ke's avatar
Guolin Ke committed
503
504
      }
    }
505
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
506
507
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
508
509
510
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
511
  } else {
512
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
513
    ret->CreateValid(
514
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
515
  }
516
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
517
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
518
  for (int i = 0; i < nrow; ++i) {
519
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
520
    const int tid = omp_get_thread_num();
521
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
522
    ret->PushOneRow(tid, i, one_row);
523
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
524
  }
525
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
526
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
527
  *out = ret.release();
528
  API_END();
529
530
}

Guolin Ke's avatar
Guolin Ke committed
531
int LGBM_DatasetCreateFromCSR(const void* indptr,
532
533
534
535
536
537
538
539
540
541
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
542
  API_BEGIN();
wxchan's avatar
wxchan committed
543
  auto param = ConfigBase::Str2Map(parameters);
544
545
546
547
548
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
549
  std::unique_ptr<Dataset> ret;
550
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
551
552
553
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
554
555
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
556
    auto sample_indices = rand.Sample(nrow, sample_cnt);
557
    sample_cnt = static_cast<int>(sample_indices.size());
558
    std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
559
    std::vector<std::vector<int>> sample_idx;
560
561
562
563
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
564
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
565
566
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
567
        }
Guolin Ke's avatar
Guolin Ke committed
568
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
569
570
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
571
572
573
        }
      }
    }
574
    CHECK(num_col >= static_cast<int>(sample_values.size()));
575
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
576
577
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
578
579
580
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
581
  } else {
582
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
583
    ret->CreateValid(
584
      reinterpret_cast<const Dataset*>(reference));
585
  }
586
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
587
  #pragma omp parallel for schedule(static)
588
  for (int i = 0; i < nindptr - 1; ++i) {
589
    OMP_LOOP_EX_BEGIN();
590
591
592
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
593
    OMP_LOOP_EX_END();
594
  }
595
  OMP_THROW_EX();
596
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
597
  *out = ret.release();
598
  API_END();
599
600
}

Guolin Ke's avatar
Guolin Ke committed
601
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
602
603
604
605
606
607
608
609
610
611
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
612
  API_BEGIN();
wxchan's avatar
wxchan committed
613
  auto param = ConfigBase::Str2Map(parameters);
614
615
616
617
618
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
619
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
620
621
622
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
623
624
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
625
    auto sample_indices = rand.Sample(nrow, sample_cnt);
626
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
627
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
628
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
629
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
630
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
631
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
632
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
633
634
635
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
636
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
637
638
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
639
640
        }
      }
641
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
642
    }
643
    OMP_THROW_EX();
644
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
645
646
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
647
648
649
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
650
  } else {
651
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
652
    ret->CreateValid(
653
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
654
  }
655
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
656
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
657
  for (int i = 0; i < ncol_ptr - 1; ++i) {
658
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
659
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
660
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
661
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
662
663
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
664
665
666
667
668
669
670
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
671
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
672
    }
673
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
674
  }
675
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
676
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
677
  *out = ret.release();
678
  API_END();
Guolin Ke's avatar
Guolin Ke committed
679
680
}

Guolin Ke's avatar
Guolin Ke committed
681
int LGBM_DatasetGetSubset(
682
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
683
684
685
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
686
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
687
688
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
689
690
691
692
693
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
694
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
695
  CHECK(num_used_row_indices > 0);
696
697
698
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
699
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
700
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
701
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
702
703
704
705
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
706
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
707
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
708
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
709
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
710
711
712
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
713
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
714
715
716
717
718
719
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
720
int LGBM_DatasetGetFeatureNames(
721
722
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
723
  int* num_feature_names) {
724
725
726
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
727
728
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
729
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
730
731
732
733
  }
  API_END();
}

734
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
735
// Deletes the Dataset object that `handle` points to.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  auto* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
741
int LGBM_DatasetSaveBinary(DatasetHandle handle,
742
                           const char* filename) {
743
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
744
745
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
746
  API_END();
747
748
}

Guolin Ke's avatar
Guolin Ke committed
749
// Stores `num_element` values from `field_data` into the field called
// `field_name` of the dataset behind `handle`.
//
// `type` selects how `field_data` is interpreted:
//   C_API_DTYPE_FLOAT32 -> const float*  (Dataset::SetFloatField)
//   C_API_DTYPE_INT32   -> const int*    (Dataset::SetIntField)
//   C_API_DTYPE_FLOAT64 -> const double* (Dataset::SetDoubleField)
// Throws std::runtime_error when `type` is none of the above or when the
// selected setter reports failure.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  // Message spelling corrected ("erorr" -> "error"); it is surfaced to API users.
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
768
// Looks up the field `field_name` on the dataset behind `handle`.
// The float, int and double getters are tried in that order; the first one
// that succeeds fills *out_len / *out_ptr and determines *out_type.
// Throws std::runtime_error("Field not found") if none succeeds.
// When the returned pointer is null, *out_len is forced to 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
791
int LGBM_DatasetGetNumData(DatasetHandle handle,
792
                           int* out) {
793
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
794
795
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
796
  API_END();
797
798
}

Guolin Ke's avatar
Guolin Ke committed
799
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
800
                              int* out) {
801
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
802
803
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
804
  API_END();
Guolin Ke's avatar
Guolin Ke committed
805
}
806
807
808

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
809
int LGBM_BoosterCreate(const DatasetHandle train_data,
810
811
                       const char* parameters,
                       BoosterHandle* out) {
812
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
813
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
814
815
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
816
  API_END();
817
818
}

Guolin Ke's avatar
Guolin Ke committed
819
int LGBM_BoosterCreateFromModelfile(
820
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
821
  int* out_num_iterations,
822
  BoosterHandle* out) {
823
  API_BEGIN();
wxchan's avatar
wxchan committed
824
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
825
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
826
  *out = ret.release();
827
  API_END();
828
829
}

Guolin Ke's avatar
Guolin Ke committed
830
int LGBM_BoosterLoadModelFromString(
831
832
833
834
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
835
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
836
837
838
839
840
841
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

842
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
843
// Deletes the Booster object that `handle` points to.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  auto* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
849
// Calls Booster::MergeFrom on the booster behind `handle`, passing the
// booster behind `other_handle`.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  auto* target = reinterpret_cast<Booster*>(handle);
  auto* other = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(other);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
858
int LGBM_BoosterAddValidData(BoosterHandle handle,
859
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
860
861
862
863
864
865
866
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
867
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
868
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
869
870
871
872
873
874
875
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
876
// Calls Booster::ResetConfig(parameters) on the booster behind `handle`.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
883
// Writes the underlying boosting object's NumberOfClasses() to *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
890
// Runs Booster::TrainOneIter() and stores 1 in *is_finished when it returns
// true, 0 otherwise.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
901
// Runs Booster::TrainOneIter(grad, hess) with caller-supplied gradient and
// hessian arrays and stores 1 in *is_finished when it returns true,
// 0 otherwise.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
915
// Calls Booster::RollbackOneIter() on the booster behind `handle`.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
922
// Writes the underlying boosting object's GetCurrentIteration() to
// *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
928

Guolin Ke's avatar
Guolin Ke committed
929
// Writes Booster::GetEvalCounts() to *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
936
// Passes out_strs to Booster::GetEvalNames and stores its return value in
// *out_len.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
943
// Passes out_strs to Booster::GetFeatureNames and stores its return value in
// *out_len.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
950
// Writes MaxFeatureIdx() + 1 of the underlying boosting object to *out_len.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
957
// Fetches the evaluation results for data_idx via GetEvalAt, writes their
// count to *out_len and copies them, converted to double, into out_results.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  auto scores = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  size_t pos = 0;
  for (const auto& score : scores) {
    out_results[pos] = static_cast<double>(score);
    ++pos;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
972
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
973
974
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
975
976
977
978
979
980
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
981
int LGBM_BoosterGetPredict(BoosterHandle handle,
982
983
984
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
985
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
986
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
987
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
988
  API_END();
Guolin Ke's avatar
Guolin Ke committed
989
990
}

Guolin Ke's avatar
Guolin Ke committed
991
int LGBM_BoosterPredictForFile(BoosterHandle handle,
992
993
994
995
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
996
                               const char* parameter,
997
                               const char* result_filename) {
998
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
999
1000
1001
1002
1003
1004
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1005
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1006
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1007
                       config.io_config, result_filename);
1008
  API_END();
1009
1010
}

Guolin Ke's avatar
Guolin Ke committed
1011
// Computes how many output values a prediction over `num_row` rows produces
// and writes the result to *out_len.
//
// The per-row count comes from NumPredictOneRow(num_iteration, is_leaf,
// is_contrib), where the two flags are derived from `predict_type`
// (C_API_PREDICT_LEAF_INDEX / C_API_PREDICT_CONTRIB).
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const bool is_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool is_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto num_pred_per_row = ref_booster->GetBoosting()->NumPredictOneRow(
    num_iteration, is_leaf_index, is_contrib);
  // Widen BEFORE multiplying: the previous code multiplied in int and cast the
  // (possibly already overflowed) product to int64_t afterwards.
  *out_len = static_cast<int64_t>(num_row) * num_pred_per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1023
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1034
                              const char* parameter,
1035
1036
                              int64_t* out_len,
                              double* out_result) {
1037
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1038
1039
1040
1041
1042
1043
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1044
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1045
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1046
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1047
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1048
                       config.io_config, out_result, out_len);
1049
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1050
}
1051

Guolin Ke's avatar
Guolin Ke committed
1052
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1063
                              const char* parameter,
1064
1065
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1066
1067
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1080
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
1084
1085
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1086
1087
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1088
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1089
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1090
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1091
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1092
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1093
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1094
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1095
1096
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1097
1098
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1099
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config.io_config,
cbecker's avatar
cbecker committed
1100
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1101
1102
1103
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1104
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1105
1106
1107
1108
1109
1110
1111
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1112
                              const char* parameter,
1113
1114
                              int64_t* out_len,
                              double* out_result) {
1115
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1116
1117
1118
1119
1120
1121
  auto param = ConfigBase::Str2Map(parameter);
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1122
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1123
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1124
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1125
                       config.io_config, out_result, out_len);
1126
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1127
}
1128

Guolin Ke's avatar
Guolin Ke committed
1129
int LGBM_BoosterSaveModel(BoosterHandle handle,
1130
1131
                          int num_iteration,
                          const char* filename) {
1132
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1133
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
1134
1135
1136
1137
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1138
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1139
                                  int num_iteration,
1140
1141
                                  int64_t buffer_len, 
                                  int64_t* out_len,
1142
                                  char* out_str) {
1143
1144
1145
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  std::string model = ref_booster->SaveModelToString(num_iteration);
1146
  *out_len = static_cast<int64_t>(model.size()) + 1;
1147
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1148
    std::memcpy(out_str, model.c_str(), *out_len);
1149
1150
1151
1152
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1153
int LGBM_BoosterDumpModel(BoosterHandle handle,
1154
                          int num_iteration,
1155
1156
                          int64_t buffer_len,
                          int64_t* out_len,
1157
                          char* out_str) {
wxchan's avatar
wxchan committed
1158
1159
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1160
  std::string model = ref_booster->DumpModel(num_iteration);
1161
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1162
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1163
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1164
  }
1165
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1166
}
1167

Guolin Ke's avatar
Guolin Ke committed
1168
// Writes Booster::GetLeafValue(tree_idx, leaf_idx), converted to double,
// to *out_val.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1178
// Calls Booster::SetLeafValue(tree_idx, leaf_idx, val) on the booster behind
// `handle`.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
// Obtains the importance vector from
// Booster::FeatureImportance(num_iteration, importance_type) and copies every
// element into out_results in order. The output buffer size is not checked.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  size_t pos = 0;
  for (const double importance : importances) {
    out_results[pos] = importance;
    ++pos;
  }
  API_END();
}

1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
  NetworkConfig config;
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Calls Network::Dispose().
int LGBM_NetworkFree() {
  API_BEGIN();
  { Network::Dispose(); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1223
int LGBM_NetworkInitWithFunctions(void* reduce_scatter_fun_ptr,
1224
                                  void* allgather_fun_ptr,
Guolin Ke's avatar
Guolin Ke committed
1225
1226
                                  int num_machines,
                                  int rank) {
ww's avatar
ww committed
1227
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1228
  if (num_machines > 1) {
Guolin Ke's avatar
Guolin Ke committed
1229
1230
    Network::SetReduceScatterFunction((ReduceScatterFunction)reduce_scatter_fun_ptr);
    Network::SetAllgatherFunction((AllgatherFunction)allgather_fun_ptr);
ww's avatar
ww committed
1231
1232
1233
1234
1235
    Network::SetNumMachines(num_machines);
    Network::SetRank(rank);
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1236

Guolin Ke's avatar
Guolin Ke committed
1237
// ---- start of some helper functions
1238
1239
1240

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1241
  if (data_type == C_API_DTYPE_FLOAT32) {
1242
1243
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1244
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1245
        std::vector<double> ret(num_col);
1246
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1247
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1248
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1249
1250
1251
1252
        }
        return ret;
      };
    } else {
1253
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1254
        std::vector<double> ret(num_col);
1255
        for (int i = 0; i < num_col; ++i) {
1256
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1257
1258
1259
1260
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1261
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1262
1263
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1264
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1265
        std::vector<double> ret(num_col);
1266
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1267
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1268
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1269
1270
1271
1272
        }
        return ret;
      };
    } else {
1273
      return [data_ptr, num_col, num_row] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1274
        std::vector<double> ret(num_col);
1275
        for (int i = 0; i < num_col; ++i) {
1276
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1277
1278
1279
1280
1281
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1282
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1283
1284
1285
1286
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1287
1288
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1289
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1290
1291
1292
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1293
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1294
          ret.emplace_back(i, raw_values[i]);
1295
        }
Guolin Ke's avatar
Guolin Ke committed
1296
1297
1298
      }
      return ret;
    };
1299
  }
Guolin Ke's avatar
Guolin Ke committed
1300
  return nullptr;
1301
1302
1303
1304
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1305
  if (data_type == C_API_DTYPE_FLOAT32) {
1306
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1307
    if (indptr_type == C_API_DTYPE_INT32) {
1308
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1309
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1310
1311
1312
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1313
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1314
          ret.emplace_back(indices[i], data_ptr[i]);
1315
1316
1317
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1318
    } else if (indptr_type == C_API_DTYPE_INT64) {
1319
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1320
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1321
1322
1323
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1324
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1325
          ret.emplace_back(indices[i], data_ptr[i]);
1326
1327
1328
1329
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1330
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1331
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1332
    if (indptr_type == C_API_DTYPE_INT32) {
1333
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1334
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1335
1336
1337
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1338
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1339
          ret.emplace_back(indices[i], data_ptr[i]);
1340
1341
1342
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1343
    } else if (indptr_type == C_API_DTYPE_INT64) {
1344
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1345
      return [ptr_indptr, indices, data_ptr, nindptr, nelem] (int idx) {
1346
1347
1348
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1349
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1350
          ret.emplace_back(indices[i], data_ptr[i]);
1351
1352
1353
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1354
1355
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1356
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1357
1358
}

Guolin Ke's avatar
Guolin Ke committed
1359
1360
1361
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1362
  if (data_type == C_API_DTYPE_FLOAT32) {
1363
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1364
    if (col_ptr_type == C_API_DTYPE_INT32) {
1365
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1366
1367
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1368
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1369
1370
1371
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1372
        }
Guolin Ke's avatar
Guolin Ke committed
1373
1374
1375
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1376
      };
Guolin Ke's avatar
Guolin Ke committed
1377
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1378
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1379
1380
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1381
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1382
1383
1384
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1385
        }
Guolin Ke's avatar
Guolin Ke committed
1386
1387
1388
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1389
      };
Guolin Ke's avatar
Guolin Ke committed
1390
    }
Guolin Ke's avatar
Guolin Ke committed
1391
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1392
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1393
    if (col_ptr_type == C_API_DTYPE_INT32) {
1394
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1395
1396
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1397
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1398
1399
1400
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1401
        }
Guolin Ke's avatar
Guolin Ke committed
1402
1403
1404
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1405
      };
Guolin Ke's avatar
Guolin Ke committed
1406
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1407
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1408
1409
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1410
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1411
1412
1413
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1414
        }
Guolin Ke's avatar
Guolin Ke committed
1415
1416
1417
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1418
      };
Guolin Ke's avatar
Guolin Ke committed
1419
1420
1421
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1422
1423
}

Guolin Ke's avatar
Guolin Ke committed
1424
// Binds the iterator to column `col_idx` of a CSC matrix. All arguments are
// forwarded unchanged to IterateFunctionFromCSC, whose result becomes the
// per-entry accessor used by Get() and NextNonZero().
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the value stored at index `idx` in this column, or 0 when no entry
// with that index is found.
// The cursor only moves forward: stored entries are consumed until one with
// an index >= idx is reached or the column is exhausted.
// NOTE(review): this presumably relies on indices within a column being
// increasing and on callers asking for non-decreasing idx; confirm.
double CSC_RowIterator::Get(int idx) {
  for (; !is_end_ && cur_idx_ < idx; ++nonzero_idx_) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // a negative index marks the end of the column
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (index, value) entry of this column and advances
// the cursor. Once the column is exhausted, the entry has a negative index
// and every later call returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // a negative index is the end marker; remember it so later calls stop early
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}