c_api.cpp 49.4 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
cbecker's avatar
cbecker committed
13
#include <LightGBM/prediction_early_stop.h>
Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
18

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
19
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
20
#include <stdexcept>
wxchan's avatar
wxchan committed
21
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
22
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
23

Guolin Ke's avatar
Guolin Ke committed
24
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
25
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
26

Guolin Ke's avatar
Guolin Ke committed
27
28
29
30
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
31
32
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
33
34
  }

35
36
37
38
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
39
  Booster(const Dataset* train_data,
40
          const char* parameters) {
wxchan's avatar
wxchan committed
41
42
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
43
44
45
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
46
47
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
48
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
49
50
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
51

wxchan's avatar
wxchan committed
52
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
53

Guolin Ke's avatar
Guolin Ke committed
54
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
55
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
56
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
57
58

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
59
60
61
62
63
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
64
65
66
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
67

Guolin Ke's avatar
Guolin Ke committed
68
  }
69

wxchan's avatar
wxchan committed
70
71
72
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
73
74
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
75
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
95
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
96
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
97
98
99
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
100
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
101
102
103
104
105
106
107
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
108
109
110
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
111
112

    config_.Set(param);
113
114
115
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
120
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
125
126
127
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
130

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
131
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
132

wxchan's avatar
wxchan committed
133
134
135
136
137
138
139
140
141
142
143
144
145
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
146
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
147
  }
Guolin Ke's avatar
Guolin Ke committed
148

149
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
150
    std::lock_guard<std::mutex> lock(mutex_);
151
152
153
154
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
155
    std::lock_guard<std::mutex> lock(mutex_);
156
157
158
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
159
160
161
162
163
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
164
165
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
cbecker's avatar
cbecker committed
166
               const PredictionEarlyStoppingHandle early_stop_handle,
Guolin Ke's avatar
Guolin Ke committed
167
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
168
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
169
170
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
171
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
172
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
173
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
174
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
175
176
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
177
    }
cbecker's avatar
cbecker committed
178
179
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
                        reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle));
Guolin Ke's avatar
Guolin Ke committed
180
181
182
183
184
185
186
187
188
189
190
191
    int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
    auto pred_fun = predictor.GetPredictFunction();
    auto pred_wrt_ptr = out_result;
    for (int i = 0; i < nrow; ++i) {
      auto one_row = get_row_fun(i);
      pred_fun(one_row, pred_wrt_ptr);
      pred_wrt_ptr += num_preb_in_one_row;
    }
    *out_len = nrow * num_preb_in_one_row;
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
cbecker's avatar
cbecker committed
192
193
               int data_has_header, const PredictionEarlyStoppingHandle early_stop_handle,
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
194
195
196
197
198
199
200
201
202
203
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else {
      is_raw_score = false;
    }
cbecker's avatar
cbecker committed
204
205
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
                        reinterpret_cast<const PredictionEarlyStopInstance*>(early_stop_handle));
Guolin Ke's avatar
Guolin Ke committed
206
207
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
208
209
  }

Guolin Ke's avatar
Guolin Ke committed
210
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
211
212
213
214
215
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
216
  }
217

218
219
220
221
222
223
224
225
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

226
227
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
228
  }
229

Guolin Ke's avatar
Guolin Ke committed
230
231
232
233
234
235
236
237
238
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
239
240
241
242
243
244
245
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
246

wxchan's avatar
wxchan committed
247
248
249
250
251
252
253
254
255
256
257
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
258
259
260
261
262
263
264
265
266
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
267
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
268

Guolin Ke's avatar
Guolin Ke committed
269
private:
270

wxchan's avatar
wxchan committed
271
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
272
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
273
274
275
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
276
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
277
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
278
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
279
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
280
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
281
282
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
283
284
285
};

}
Guolin Ke's avatar
Guolin Ke committed
286
287
288

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
289
290
291
292
293
294
295
296
297
298
// Forward declarations of the helpers that adapt raw C-API buffers into row
// accessors: each returned callable maps a row index to that row's contents.

// Dense matrix -> the row as a vector of doubles.
auto RowFunctionFromDenseMatric(const void* data, int num_row, int num_col,
                                int data_type, int is_row_major)
    -> std::function<std::vector<double>(int row_idx)>;

// Dense matrix -> the row as (column index, value) pairs.
auto RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col,
                                    int data_type, int is_row_major)
    -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

// CSR matrix (indptr/indices/data) -> the row as (column index, value) pairs.
auto RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                        const void* data, int data_type, int64_t nindptr, int64_t nelem)
    -> std::function<std::vector<std::pair<int, double>>(int idx)>;
Guolin Ke's avatar
Guolin Ke committed
300
301
302
303
304

// Walks the rows of a single column of a CSC matrix.
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Value stored at row `idx`; rows must be requested in ascending order.
  double Get(int idx);
  // Next non-zero (row, value) pair; a negative row index signals no more data.
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_{0};
  int cur_idx_{-1};
  double cur_val_{0.0};
  bool is_end_{false};
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
321
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
322
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
323
324
}

Guolin Ke's avatar
Guolin Ke committed
325
int LGBM_DatasetCreateFromFile(const char* filename,
326
327
328
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
329
  API_BEGIN();
wxchan's avatar
wxchan committed
330
  auto param = ConfigBase::Str2Map(parameters);
331
332
333
334
335
336
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  DatasetLoader loader(config.io_config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
337
  if (reference == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
338
    *out = loader.LoadFromFile(filename);
Guolin Ke's avatar
Guolin Ke committed
339
  } else {
Guolin Ke's avatar
Guolin Ke committed
340
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
341
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
342
  }
343
  API_END();
Guolin Ke's avatar
Guolin Ke committed
344
345
}

346

Guolin Ke's avatar
Guolin Ke committed
347
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
348
349
350
351
352
353
354
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
355
356
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
357
358
359
360
361
362
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
363
364
365
366
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
367
368
}

369

Guolin Ke's avatar
Guolin Ke committed
370
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
371
372
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
377
378
379
380
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
381
int LGBM_DatasetPushRows(DatasetHandle dataset,
382
383
384
385
386
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
387
388
389
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
390
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
391
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
392
  for (int i = 0; i < nrow; ++i) {
393
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
394
395
396
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
397
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
398
  }
399
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
404
405
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
406
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t /* unused */,
                              int64_t start_row) {
  API_BEGIN();
  auto* target = reinterpret_cast<Dataset*>(dataset);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // a CSR index pointer holds one more entry than there are rows
  const int32_t num_rows = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < num_rows; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto values = row_at(row);
    target->PushOneRow(tid, static_cast<data_size_t>(start_row + row), values);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // finalize when this batch ends exactly at the dataset's last row
  const int64_t end_row = start_row + num_rows;
  if (end_row == static_cast<int64_t>(target->num_data())) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
437
int LGBM_DatasetCreateFromMat(const void* data,
438
439
440
441
442
443
444
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
445
  API_BEGIN();
wxchan's avatar
wxchan committed
446
  auto param = ConfigBase::Str2Map(parameters);
447
448
449
450
451
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
452
  std::unique_ptr<Dataset> ret;
453
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
454
455
  if (reference == nullptr) {
    // sample data first
456
457
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
458
    auto sample_indices = rand.Sample(nrow, sample_cnt);
459
    sample_cnt = static_cast<int>(sample_indices.size());
460
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
461
    std::vector<std::vector<int>> sample_idx(ncol);
Guolin Ke's avatar
Guolin Ke committed
462
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
463
      auto idx = sample_indices[i];
464
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
465
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
466
467
468
        if (std::fabs(row[j]) > kEpsilon) {
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
469
        }
Guolin Ke's avatar
Guolin Ke committed
470
471
      }
    }
472
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
473
474
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
475
476
477
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
478
  } else {
479
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
480
    ret->CreateValid(
481
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
482
  }
483
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
484
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
485
  for (int i = 0; i < nrow; ++i) {
486
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
487
    const int tid = omp_get_thread_num();
488
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
489
    ret->PushOneRow(tid, i, one_row);
490
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
491
  }
492
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
493
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
494
  *out = ret.release();
495
  API_END();
496
497
}

Guolin Ke's avatar
Guolin Ke committed
498
int LGBM_DatasetCreateFromCSR(const void* indptr,
499
500
501
502
503
504
505
506
507
508
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
509
  API_BEGIN();
wxchan's avatar
wxchan committed
510
  auto param = ConfigBase::Str2Map(parameters);
511
512
513
514
515
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
516
  std::unique_ptr<Dataset> ret;
517
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
518
519
520
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
521
522
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
523
    auto sample_indices = rand.Sample(nrow, sample_cnt);
524
    sample_cnt = static_cast<int>(sample_indices.size());
525
    std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
526
    std::vector<std::vector<int>> sample_idx;
527
528
529
530
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
531
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
532
533
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
534
        }
Guolin Ke's avatar
Guolin Ke committed
535
536
537
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
538
539
540
        }
      }
    }
541
    CHECK(num_col >= static_cast<int>(sample_values.size()));
542
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
543
544
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
545
546
547
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
548
  } else {
549
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
550
    ret->CreateValid(
551
      reinterpret_cast<const Dataset*>(reference));
552
  }
553
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
554
  #pragma omp parallel for schedule(static)
555
  for (int i = 0; i < nindptr - 1; ++i) {
556
    OMP_LOOP_EX_BEGIN();
557
558
559
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
560
    OMP_LOOP_EX_END();
561
  }
562
  OMP_THROW_EX();
563
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
564
  *out = ret.release();
565
  API_END();
566
567
}

Guolin Ke's avatar
Guolin Ke committed
568
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
569
570
571
572
573
574
575
576
577
578
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
579
  API_BEGIN();
wxchan's avatar
wxchan committed
580
  auto param = ConfigBase::Str2Map(parameters);
581
582
583
584
585
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
586
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
587
588
589
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
590
591
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
592
    auto sample_indices = rand.Sample(nrow, sample_cnt);
593
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
594
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
595
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
596
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
597
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
598
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
599
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
600
601
602
603
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
Guolin Ke's avatar
Guolin Ke committed
604
605
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
606
607
        }
      }
608
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
609
    }
610
    OMP_THROW_EX();
611
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
612
613
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
614
615
616
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
617
  } else {
618
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
619
    ret->CreateValid(
620
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
621
  }
622
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
623
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
624
  for (int i = 0; i < ncol_ptr - 1; ++i) {
625
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
626
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
627
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
628
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
629
630
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
631
632
633
634
635
636
637
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
638
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
639
    }
640
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
641
  }
642
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
643
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
644
  *out = ret.release();
645
  API_END();
Guolin Ke's avatar
Guolin Ke committed
646
647
}

Guolin Ke's avatar
Guolin Ke committed
648
int LGBM_DatasetGetSubset(
649
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
650
651
652
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
653
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
654
655
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
656
657
658
659
660
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
661
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
662
  CHECK(num_used_row_indices > 0);
Guolin Ke's avatar
Guolin Ke committed
663
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
664
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
665
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
666
667
668
669
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
670
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
671
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
672
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
673
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
674
675
676
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
677
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
678
679
680
681
682
683
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
684
int LGBM_DatasetGetFeatureNames(
685
686
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
687
  int* num_feature_names) {
688
689
690
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
691
692
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
693
694
695
696
697
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
698
// Deletes the Dataset behind `handle`; the handle must not be used afterwards.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
704
int LGBM_DatasetSaveBinary(DatasetHandle handle,
705
                           const char* filename) {
706
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
707
708
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
709
  API_END();
710
711
}

Guolin Ke's avatar
Guolin Ke committed
712
// Sets the metadata field `field_name` of a dataset from a typed buffer of
// num_element values. `type` selects how field_data is interpreted:
//   C_API_DTYPE_FLOAT32 -> const float*
//   C_API_DTYPE_INT32   -> const int*
//   C_API_DTYPE_FLOAT64 -> const double*
// Throws std::runtime_error when `type` is none of these or when the Dataset
// setter reports failure (e.g. unknown field for that type); the exception is
// handled by API_END (presumably converted to an error return — see its macro).
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, static_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, static_cast<const int*>(field_data), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, static_cast<const double*>(field_data), len);
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
731
// Looks up metadata field `field_name`, trying the float, int and double
// getters in that order. On success *out_ptr/*out_len describe the field data
// and *out_type records which C_API_DTYPE_* matched; if the data pointer is
// null, *out_len is forced to 0. Throws std::runtime_error if no getter knows
// the field.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
754
int LGBM_DatasetGetNumData(DatasetHandle handle,
755
                           int* out) {
756
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
757
758
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
759
  API_END();
760
761
}

Guolin Ke's avatar
Guolin Ke committed
762
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
763
                              int* out) {
764
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
765
766
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
767
  API_END();
Guolin Ke's avatar
Guolin Ke committed
768
}
769
770
771

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
772
int LGBM_BoosterCreate(const DatasetHandle train_data,
773
774
                       const char* parameters,
                       BoosterHandle* out) {
775
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
776
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
777
778
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
779
  API_END();
780
781
}

Guolin Ke's avatar
Guolin Ke committed
782
int LGBM_BoosterCreateFromModelfile(
783
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
784
  int* out_num_iterations,
785
  BoosterHandle* out) {
786
  API_BEGIN();
wxchan's avatar
wxchan committed
787
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
788
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
789
  *out = ret.release();
790
  API_END();
791
792
}

Guolin Ke's avatar
Guolin Ke committed
793
int LGBM_BoosterLoadModelFromString(
794
795
796
797
798
799
800
801
802
803
804
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  auto ret = std::unique_ptr<Booster>(new Booster());
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
805
// Deletes the Booster behind `handle`; the handle must not be used afterwards.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
811
// Merges the booster behind other_handle into the booster behind handle
// (Booster::MergeFrom).
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* target = reinterpret_cast<Booster*>(handle);
  target->MergeFrom(reinterpret_cast<Booster*>(other_handle));
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
820
int LGBM_BoosterAddValidData(BoosterHandle handle,
821
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
822
823
824
825
826
827
828
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
829
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
830
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
831
832
833
834
835
836
837
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
838
// Applies a new parameter string to the booster (Booster::ResetConfig).
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
845
// Stores the model's class count (Boosting::NumberOfClasses) in *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
852
// Runs one training iteration with the booster's own objective.
// *is_finished becomes 1 when TrainOneIter() reports true, 0 otherwise.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
863
// Runs one training iteration using caller-supplied gradients and hessians.
// *is_finished becomes 1 when TrainOneIter(grad, hess) reports true, 0 otherwise.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
877
// Undoes the most recent training iteration (Booster::RollbackOneIter).
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
884
// Stores the model's current iteration (Boosting::GetCurrentIteration) in *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
890

Guolin Ke's avatar
Guolin Ke committed
891
// Stores the number of evaluation metrics (Booster::GetEvalCounts) in *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto metric_count = booster->GetEvalCounts();
  *out_len = metric_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
898
// Fills out_strs with the evaluation metric names (Booster::GetEvalNames) and
// stores the count it returns in *out_len.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetEvalNames(out_strs);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
905
// Fills out_strs with the model's feature names (Booster::GetFeatureNames) and
// stores the count it returns in *out_len.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetFeatureNames(out_strs);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
912
// Stores the model's feature count in *out_len, computed as the largest
// feature index (Boosting::MaxFeatureIdx) plus one.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto max_index = boosting->MaxFeatureIdx();
  *out_len = max_index + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
919
// Copies the evaluation results for dataset `data_idx` into out_results and
// stores how many were written in *out_len. out_results must have room for
// every metric value (see LGBM_BoosterGetEvalCounts).
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto metric_values = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(metric_values.size());
  double* dst = out_results;
  for (const auto& value : metric_values) {
    *dst = static_cast<double>(value);
    ++dst;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
934
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
935
936
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
937
938
939
940
941
942
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
943
int LGBM_BoosterGetPredict(BoosterHandle handle,
944
945
946
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
947
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
948
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
949
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
950
  API_END();
Guolin Ke's avatar
Guolin Ke committed
951
952
}

Guolin Ke's avatar
Guolin Ke committed
953
int LGBM_BoosterPredictForFile(BoosterHandle handle,
954
955
956
957
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
cbecker's avatar
cbecker committed
958
                               const PredictionEarlyStoppingHandle early_stop_handle,
959
                               const char* result_filename) {
960
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
961
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
962
963
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
                       early_stop_handle, result_filename);
964
  API_END();
965
966
}

Guolin Ke's avatar
Guolin Ke committed
967
// Computes how many output values a prediction over num_row rows produces:
// num_row times the per-row count returned by Boosting::NumPredictOneRow
// for the given iteration limit and prediction type (leaf-index or not).
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const bool is_predict_leaf = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const int64_t per_row = static_cast<int64_t>(
    ref_booster->GetBoosting()->NumPredictOneRow(num_iteration, is_predict_leaf));
  // Widen before multiplying: the product can exceed INT_MAX for large inputs,
  // and overflowing a signed int is undefined behavior. The previous code
  // multiplied in int and only then cast the (possibly overflowed) result.
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
979
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
980
981
982
983
984
985
986
987
988
989
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
cbecker's avatar
cbecker committed
990
                              const PredictionEarlyStoppingHandle early_stop_handle,
991
992
                              int64_t* out_len,
                              double* out_result) {
993
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
994
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
995
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
996
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
997
998
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
                       early_stop_handle, out_result, out_len);
999
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1000
}
1001

Guolin Ke's avatar
Guolin Ke committed
1002
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
cbecker's avatar
cbecker committed
1013
                              const PredictionEarlyStoppingHandle early_stop_handle,
1014
1015
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1016
1017
1018
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1019
1020
1021
1022
1023
1024
  std::vector<CSC_RowIterator> iterators;
  for (int j = 0; j < ncol; ++j) {
    iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
    [&iterators, ncol](int i) {
Guolin Ke's avatar
Guolin Ke committed
1025
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1026
1027
1028
1029
    for (int j = 0; j < ncol; ++j) {
      auto val = iterators[j].Get(i);
      if (std::fabs(val) > kEpsilon) {
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1030
1031
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1032
1033
    return one_row;
  };
cbecker's avatar
cbecker committed
1034
1035
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, early_stop_handle,
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1036
1037
1038
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1039
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1040
1041
1042
1043
1044
1045
1046
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
cbecker's avatar
cbecker committed
1047
                              const PredictionEarlyStoppingHandle early_stop_handle,
1048
1049
                              int64_t* out_len,
                              double* out_result) {
1050
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1051
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1052
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1053
1054
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
                      early_stop_handle, out_result, out_len);
1055
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1056
}
1057

Guolin Ke's avatar
Guolin Ke committed
1058
int LGBM_BoosterSaveModel(BoosterHandle handle,
1059
1060
                          int num_iteration,
                          const char* filename) {
1061
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1062
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
1063
1064
1065
1066
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1067
// Serializes the model into out_str. *out_len always receives the size needed
// including the terminating NUL; out_str is written only when buffer_len is at
// least that large (presumably so callers can size a buffer and call again).
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                  int num_iteration,
                                  int buffer_len,
                                  int* out_len,
                                  char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string serialized = booster->SaveModelToString(num_iteration);
  const int required_len = static_cast<int>(serialized.size()) + 1;
  *out_len = required_len;
  if (buffer_len >= required_len) {
    std::strcpy(out_str, serialized.c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1082
// Dumps the model (Booster::DumpModel) into out_str. *out_len always receives
// the size needed including the terminating NUL; out_str is written only when
// buffer_len is at least that large.
int LGBM_BoosterDumpModel(BoosterHandle handle,
                          int num_iteration,
                          int buffer_len,
                          int* out_len,
                          char* out_str) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::string dumped = booster->DumpModel(num_iteration);
  const int required_len = static_cast<int>(dumped.size()) + 1;
  *out_len = required_len;
  if (buffer_len >= required_len) {
    std::strcpy(out_str, dumped.c_str());
  }
  API_END();
}
1096

Guolin Ke's avatar
Guolin Ke committed
1097
// Reads the output value of leaf `leaf_idx` in tree `tree_idx` into *out_val.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1107
// Overwrites the output value of leaf `leaf_idx` in tree `tree_idx` with `val`.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

cbecker's avatar
cbecker committed
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141

int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
                                         int   round_period,
                                         double margin_threshold,
                                         PredictionEarlyStoppingHandle* out)
{
  API_BEGIN();
  PredictionEarlyStopConfig config;
  config.marginThreshold = margin_threshold;
  config.roundPeriod = round_period;

  auto earlyStop = createPredictionEarlyStopInstance(type, config);

    // create new by copying
  *out = new PredictionEarlyStopInstance(earlyStop);
  API_END();
}

int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle)
{
  API_BEGIN();
  delete reinterpret_cast<const PredictionEarlyStopInstance*>(handle);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1142
// ---- start of some help functions
1143
1144
1145

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1146
  if (data_type == C_API_DTYPE_FLOAT32) {
1147
1148
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1149
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1150
        std::vector<double> ret(num_col);
1151
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1152
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1153
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1154
1155
1156
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1157
1158
1159
1160
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1161
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1162
        std::vector<double> ret(num_col);
1163
        for (int i = 0; i < num_col; ++i) {
1164
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1165
1166
1167
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1168
1169
1170
1171
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1172
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1173
1174
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1175
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1176
        std::vector<double> ret(num_col);
1177
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1178
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1179
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1180
1181
1182
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1183
1184
1185
1186
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1187
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1188
        std::vector<double> ret(num_col);
1189
        for (int i = 0; i < num_col; ++i) {
1190
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1191
1192
1193
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1194
1195
1196
1197
1198
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1199
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1200
1201
1202
1203
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1204
1205
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
1206
    return [inner_function](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1207
1208
1209
1210
1211
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1212
        }
Guolin Ke's avatar
Guolin Ke committed
1213
1214
1215
      }
      return ret;
    };
1216
  }
Guolin Ke's avatar
Guolin Ke committed
1217
  return nullptr;
1218
1219
1220
1221
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1222
  if (data_type == C_API_DTYPE_FLOAT32) {
1223
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1224
    if (indptr_type == C_API_DTYPE_INT32) {
1225
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1226
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1227
1228
1229
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1230
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1231
1232
1233
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1234
1235
1236
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1237
    } else if (indptr_type == C_API_DTYPE_INT64) {
1238
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1239
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1240
1241
1242
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1243
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1244
1245
1246
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1247
1248
1249
1250
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1251
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1252
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1253
    if (indptr_type == C_API_DTYPE_INT32) {
1254
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1255
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1256
1257
1258
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1259
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1260
1261
1262
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1263
1264
1265
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1266
    } else if (indptr_type == C_API_DTYPE_INT64) {
1267
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1268
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1269
1270
1271
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1272
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1273
1274
1275
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1276
1277
1278
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1279
1280
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1281
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1282
1283
}

Guolin Ke's avatar
Guolin Ke committed
1284
1285
1286
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1287
  if (data_type == C_API_DTYPE_FLOAT32) {
1288
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1289
    if (col_ptr_type == C_API_DTYPE_INT32) {
1290
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1291
1292
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1293
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1294
1295
1296
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1297
        }
Guolin Ke's avatar
Guolin Ke committed
1298
1299
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1300
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1301
        return std::make_pair(idx, val);
1302
      };
Guolin Ke's avatar
Guolin Ke committed
1303
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1304
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1305
1306
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1307
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1308
1309
1310
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1311
        }
Guolin Ke's avatar
Guolin Ke committed
1312
1313
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1314
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1315
        return std::make_pair(idx, val);
1316
      };
Guolin Ke's avatar
Guolin Ke committed
1317
    }
Guolin Ke's avatar
Guolin Ke committed
1318
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1319
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1320
    if (col_ptr_type == C_API_DTYPE_INT32) {
1321
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1322
1323
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1324
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1325
1326
1327
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1328
        }
Guolin Ke's avatar
Guolin Ke committed
1329
1330
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1331
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1332
        return std::make_pair(idx, val);
1333
      };
Guolin Ke's avatar
Guolin Ke committed
1334
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1335
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1336
1337
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1338
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1339
1340
1341
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1342
        }
Guolin Ke's avatar
Guolin Ke committed
1343
1344
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1345
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1346
        return std::make_pair(idx, val);
1347
      };
Guolin Ke's avatar
Guolin Ke committed
1348
1349
1350
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1351
1352
}

Guolin Ke's avatar
Guolin Ke committed
1353
/*!
 * \brief Creates a forward-only cursor over column col_idx of a CSC matrix.
 *        All arguments are forwarded unchanged to IterateFunctionFromCSC,
 *        whose result becomes this iterator's entry accessor.
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                       ncol_ptr, nelem, col_idx)) {
  // NOTE(review): the remaining cursor state is not touched here; it relies on
  // whatever initializers the class declaration provides.
}

/*!
 * \brief Value stored at index idx of the column, or 0.0 if none is stored.
 *
 * The cursor only moves forward. It advances until it reaches or passes idx,
 * or until the column is exhausted. A request for an index behind the current
 * cursor therefore yields 0.0 unless it equals the cursor's index.
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Column exhausted; the cursor stays on its last entry.
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

/*!
 * \brief Returns the next stored (index, value) pair of the column.
 * \return The next entry, or a pair whose first element is negative once the
 *         column is exhausted. After that, every call returns (-1, 0.0).
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  const auto entry = iter_fun_(nonzero_idx_);
  // The offset advances even on the terminating call.
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}