c_api.cpp 47.8 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
164
165
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
166
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
167
168
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
169
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
170
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
171
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
172
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
173
174
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
175
    }
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
    int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
    auto pred_fun = predictor.GetPredictFunction();
    auto pred_wrt_ptr = out_result;
    for (int i = 0; i < nrow; ++i) {
      auto one_row = get_row_fun(i);
      pred_fun(one_row, pred_wrt_ptr);
      pred_wrt_ptr += num_preb_in_one_row;
    }
    *out_len = nrow * num_preb_in_one_row;
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
               int data_has_header, const char* result_filename) {
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else {
      is_raw_score = false;
    }
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
203
204
  }

Guolin Ke's avatar
Guolin Ke committed
205
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
206
207
208
209
210
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
211
  }
212

213
214
215
216
217
218
219
220
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

221
222
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
223
  }
224

Guolin Ke's avatar
Guolin Ke committed
225
226
227
228
229
230
231
232
233
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
234
235
236
237
238
239
240
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
241

wxchan's avatar
wxchan committed
242
243
244
245
246
247
248
249
250
251
252
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
253
254
255
256
257
258
259
260
261
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
262
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
263

Guolin Ke's avatar
Guolin Ke committed
264
private:
265

wxchan's avatar
wxchan committed
266
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
267
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
268
269
270
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
271
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
272
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
273
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
274
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
275
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
276
277
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
278
279
280
};

}
Guolin Ke's avatar
Guolin Ke committed
281
282
283

// Bring the LightGBM namespace into file scope so the C API functions below
// can name its types (Dataset, OverallConfig, ...) without qualification.
using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
284
285
286
287
288
289
290
291
292
293
// Forward declarations of the helpers that turn raw caller buffers into
// per-row accessor callables.

// Accessor over a dense matrix: the returned callable yields one row as a
// vector of doubles.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data,
                           int num_row,
                           int num_col,
                           int data_type,
                           int is_row_major);

// Accessor over a dense matrix: the returned callable yields one row as
// (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data,
                               int num_row,
                               int num_col,
                               int data_type,
                               int is_row_major);

// Accessor over a CSR matrix: the returned callable yields one row as
// (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr,
                   int indptr_type,
                   const int32_t* indices,
                   const void* data,
                   int data_type,
                   int64_t nindptr,
                   int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
295
296
297
298
299

// Iterates over the rows of a single column of a CSC matrix.
class CSC_RowIterator {
public:
  // Binds the iterator to column `col_idx` of the CSC buffers.
  CSC_RowIterator(const void* col_ptr,
                  int col_ptr_type,
                  const int32_t* indices,
                  const void* data,
                  int data_type,
                  int64_t ncol_ptr,
                  int64_t nelem,
                  int col_idx);
  ~CSC_RowIterator() {}

  // Value stored at row `idx`; rows must be requested in ascending order.
  double Get(int idx);

  // Next non-zero (row, value) pair; a negative row index means the column
  // has no more data.
  std::pair<int, double> NextNonZero();

private:
  // Iteration state; NOTE(review): exact roles are set by the out-of-view
  // member definitions — confirm there.
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  // Callable returning an (int, double) pair for a given position.
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
316
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
317
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
318
319
}

Guolin Ke's avatar
Guolin Ke committed
320
int LGBM_DatasetCreateFromFile(const char* filename,
321
322
323
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
324
  API_BEGIN();
wxchan's avatar
wxchan committed
325
  auto param = ConfigBase::Str2Map(parameters);
326
327
328
329
330
331
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  DatasetLoader loader(config.io_config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
332
  if (reference == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
333
    *out = loader.LoadFromFile(filename);
Guolin Ke's avatar
Guolin Ke committed
334
  } else {
Guolin Ke's avatar
Guolin Ke committed
335
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
336
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
337
  }
338
  API_END();
Guolin Ke's avatar
Guolin Ke committed
339
340
}

341

Guolin Ke's avatar
Guolin Ke committed
342
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
343
344
345
346
347
348
349
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
350
351
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
352
353
354
355
356
357
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
358
359
360
361
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
362
363
}

364

Guolin Ke's avatar
Guolin Ke committed
365
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
366
367
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
368
369
370
371
372
373
374
375
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
376
int LGBM_DatasetPushRows(DatasetHandle dataset,
377
378
379
380
381
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
382
383
384
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
385
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
386
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
387
  for (int i = 0; i < nrow; ++i) {
388
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
389
390
391
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
392
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
393
  }
394
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
395
396
397
398
399
400
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
401
/*!
 * \brief Push a block of CSR rows into an existing dataset.
 * \param dataset Handle of the target dataset
 * \param indptr CSR row pointer array
 * \param indptr_type Element type code of indptr
 * \param indices CSR column indices
 * \param data CSR values
 * \param data_type Element type code of data
 * \param nindptr Length of indptr; the block holds nindptr - 1 rows
 * \param nelem Number of stored elements
 * \param start_row Row index in the dataset where this block begins
 * \return Status code produced by the API_BEGIN/API_END wrapper
 * \note The unnamed int64_t parameter is accepted but not used.
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto target = reinterpret_cast<Dataset*>(dataset);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto sparse_row = read_row(row);
    const auto dest_row = static_cast<data_size_t>(start_row + row);
    target->PushOneRow(thread_id, dest_row, sparse_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // the block that reaches the final row finishes loading
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
432
int LGBM_DatasetCreateFromMat(const void* data,
433
434
435
436
437
438
439
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
440
  API_BEGIN();
wxchan's avatar
wxchan committed
441
  auto param = ConfigBase::Str2Map(parameters);
442
443
444
445
446
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
447
  std::unique_ptr<Dataset> ret;
448
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
449
450
  if (reference == nullptr) {
    // sample data first
451
452
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
453
    auto sample_indices = rand.Sample(nrow, sample_cnt);
454
    sample_cnt = static_cast<int>(sample_indices.size());
455
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
456
    std::vector<std::vector<int>> sample_idx(ncol);
Guolin Ke's avatar
Guolin Ke committed
457
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
458
      auto idx = sample_indices[i];
459
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
460
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
461
462
463
        if (std::fabs(row[j]) > kEpsilon) {
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
464
        }
Guolin Ke's avatar
Guolin Ke committed
465
466
      }
    }
467
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
468
469
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
470
471
472
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
473
  } else {
474
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
475
    ret->CreateValid(
476
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
477
  }
478
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
479
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
480
  for (int i = 0; i < nrow; ++i) {
481
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
482
    const int tid = omp_get_thread_num();
483
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
484
    ret->PushOneRow(tid, i, one_row);
485
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
486
  }
487
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
488
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
489
  *out = ret.release();
490
  API_END();
491
492
}

Guolin Ke's avatar
Guolin Ke committed
493
int LGBM_DatasetCreateFromCSR(const void* indptr,
494
495
496
497
498
499
500
501
502
503
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
504
  API_BEGIN();
wxchan's avatar
wxchan committed
505
  auto param = ConfigBase::Str2Map(parameters);
506
507
508
509
510
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
511
  std::unique_ptr<Dataset> ret;
512
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
513
514
515
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
516
517
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
518
    auto sample_indices = rand.Sample(nrow, sample_cnt);
519
    sample_cnt = static_cast<int>(sample_indices.size());
520
    std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
521
    std::vector<std::vector<int>> sample_idx;
522
523
524
525
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
526
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
527
528
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
529
        }
Guolin Ke's avatar
Guolin Ke committed
530
531
532
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
533
534
535
        }
      }
    }
536
    CHECK(num_col >= static_cast<int>(sample_values.size()));
537
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
538
539
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
540
541
542
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
543
  } else {
544
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
545
    ret->CreateValid(
546
      reinterpret_cast<const Dataset*>(reference));
547
  }
548
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
549
  #pragma omp parallel for schedule(static)
550
  for (int i = 0; i < nindptr - 1; ++i) {
551
    OMP_LOOP_EX_BEGIN();
552
553
554
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
555
    OMP_LOOP_EX_END();
556
  }
557
  OMP_THROW_EX();
558
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
559
  *out = ret.release();
560
  API_END();
561
562
}

Guolin Ke's avatar
Guolin Ke committed
563
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
564
565
566
567
568
569
570
571
572
573
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
574
  API_BEGIN();
wxchan's avatar
wxchan committed
575
  auto param = ConfigBase::Str2Map(parameters);
576
577
578
579
580
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
581
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
582
583
584
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
585
586
    Random rand(config.io_config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
587
    auto sample_indices = rand.Sample(nrow, sample_cnt);
588
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
589
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
590
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
591
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
592
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
593
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
594
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
595
596
597
598
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
Guolin Ke's avatar
Guolin Ke committed
599
600
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
601
602
        }
      }
603
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
604
    }
605
    OMP_THROW_EX();
606
    DatasetLoader loader(config.io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
607
608
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
609
610
611
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
612
  } else {
613
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
614
    ret->CreateValid(
615
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
616
  }
617
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
618
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
619
  for (int i = 0; i < ncol_ptr - 1; ++i) {
620
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
621
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
622
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
623
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
624
625
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
626
627
628
629
630
631
632
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
633
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
634
    }
635
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
636
  }
637
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
638
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
639
  *out = ret.release();
640
  API_END();
Guolin Ke's avatar
Guolin Ke committed
641
642
}

Guolin Ke's avatar
Guolin Ke committed
643
int LGBM_DatasetGetSubset(
644
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
645
646
647
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
648
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
649
650
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
651
652
653
654
655
  OverallConfig config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
656
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
657
  CHECK(num_used_row_indices > 0);
Guolin Ke's avatar
Guolin Ke committed
658
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
659
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
660
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
661
662
663
664
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
665
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
666
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
667
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
668
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
669
670
671
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
672
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
673
674
675
676
677
678
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
679
int LGBM_DatasetGetFeatureNames(
680
681
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
682
  int* num_feature_names) {
683
684
685
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
686
687
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
688
689
690
691
692
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
693
// Destroys the Dataset object behind the handle.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
699
int LGBM_DatasetSaveBinary(DatasetHandle handle,
700
                           const char* filename) {
701
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
702
703
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
704
  API_END();
705
706
}

Guolin Ke's avatar
Guolin Ke committed
707
// Stores an array of num_element values on the dataset under field_name.
//   type selects how field_data is interpreted:
//     C_API_DTYPE_FLOAT32 -> const float*  (Dataset::SetFloatField)
//     C_API_DTYPE_INT32   -> const int*    (Dataset::SetIntField)
//     C_API_DTYPE_FLOAT64 -> const double* (Dataset::SetDoubleField)
// Throws std::runtime_error when the type is none of the above or the
// selected setter reports failure.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
  }
  // Message previously read "erorr"; spelling corrected.
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
726
// Looks up field_name on the dataset, trying the float, int and double
// getters in that order. The first getter that succeeds determines
// *out_type; *out_len and *out_ptr are filled by that getter.
// Throws std::runtime_error("Field not found") when all three fail.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  bool found = true;
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    found = false;
  }
  if (!found) { throw std::runtime_error("Field not found"); }
  // A field that exists but holds no data is reported with length zero.
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
749
int LGBM_DatasetGetNumData(DatasetHandle handle,
750
                           int* out) {
751
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
752
753
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
754
  API_END();
755
756
}

Guolin Ke's avatar
Guolin Ke committed
757
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
758
                              int* out) {
759
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
760
761
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
762
  API_END();
Guolin Ke's avatar
Guolin Ke committed
763
}
764
765
766

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
767
int LGBM_BoosterCreate(const DatasetHandle train_data,
768
769
                       const char* parameters,
                       BoosterHandle* out) {
770
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
771
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
772
773
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
774
  API_END();
775
776
}

Guolin Ke's avatar
Guolin Ke committed
777
int LGBM_BoosterCreateFromModelfile(
778
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
779
  int* out_num_iterations,
780
  BoosterHandle* out) {
781
  API_BEGIN();
wxchan's avatar
wxchan committed
782
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
783
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
784
  *out = ret.release();
785
  API_END();
786
787
}

Guolin Ke's avatar
Guolin Ke committed
788
int LGBM_BoosterLoadModelFromString(
789
790
791
792
793
794
795
796
797
798
799
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  auto ret = std::unique_ptr<Booster>(new Booster());
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
800
// Destroys the Booster object behind the handle.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
806
// Calls Booster::MergeFrom on the first booster with the second as argument.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(source);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
815
int LGBM_BoosterAddValidData(BoosterHandle handle,
816
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
817
818
819
820
821
822
823
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
824
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
825
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
826
827
828
829
830
831
832
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
833
// Passes the parameter string to Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
840
// Writes the underlying boosting object's NumberOfClasses() to *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
847
// Runs Booster::TrainOneIter() and writes 1 to *is_finished when it
// returns true, otherwise 0.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* trainer = reinterpret_cast<Booster*>(handle);
  *is_finished = trainer->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
858
// Runs Booster::TrainOneIter(grad, hess) with caller-supplied gradient and
// hessian arrays and writes 1 to *is_finished when it returns true,
// otherwise 0.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* trainer = reinterpret_cast<Booster*>(handle);
  *is_finished = trainer->TrainOneIter(grad, hess) ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
872
// Forwards to Booster::RollbackOneIter.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
879
// Writes the underlying boosting object's GetCurrentIteration() to
// *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
885

Guolin Ke's avatar
Guolin Ke committed
886
// Writes Booster::GetEvalCounts() to *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
893
// Passes the caller's string buffers to Booster::GetEvalNames and stores
// its return value in *out_len.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
900
// Passes the caller's string buffers to Booster::GetFeatureNames and stores
// its return value in *out_len.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
907
// Writes MaxFeatureIdx() + 1 of the underlying boosting object to *out_len.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
914
// Fetches the evaluation results for dataset index data_idx via
// GetEvalAt, writes their count to *out_len and copies each value (as
// double) into out_results, which the caller must have sized accordingly.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst = static_cast<double>(score);
    ++dst;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
929
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
930
931
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
932
933
934
935
936
937
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
938
int LGBM_BoosterGetPredict(BoosterHandle handle,
939
940
941
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
942
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
943
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
944
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
945
  API_END();
Guolin Ke's avatar
Guolin Ke committed
946
947
}

Guolin Ke's avatar
Guolin Ke committed
948
// Forwards to the file-based Booster::Predict overload. Note the argument
// order of that call differs from this function's parameter order.
int LGBM_BoosterPredictForFile(BoosterHandle handle,
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
                               const char* result_filename) {
  API_BEGIN();
  Booster* predictor = reinterpret_cast<Booster*>(handle);
  predictor->Predict(num_iteration,
                     predict_type,
                     data_filename,
                     data_has_header,
                     result_filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
960
// Computes how many prediction values num_row rows will produce:
//   num_row * NumPredictOneRow(num_iteration, is_leaf_index_prediction)
// and stores the result in *out_len.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const int64_t per_row = ref_booster->GetBoosting()->NumPredictOneRow(
    num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX);
  // Widen before multiplying: the product was previously formed in the
  // narrower operand type and only then cast, so it could overflow.
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
972
// Predicts for a CSR matrix. A row accessor is built with
// RowFunctionFromCSR and the row count is taken as nindptr - 1.
// The eighth (unnamed) parameter is not used.
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int num_rows = static_cast<int>(nindptr - 1);
  Booster* predictor = reinterpret_cast<Booster*>(handle);
  predictor->Predict(num_iteration, predict_type, num_rows, row_reader, out_result, out_len);
  API_END();
}
992

Guolin Ke's avatar
Guolin Ke committed
993
// Predicts for a CSC matrix. One CSC_RowIterator is created per column
// (ncol_ptr - 1 columns); the row callback asks every column iterator for
// its value at the requested row and keeps entries whose magnitude exceeds
// kEpsilon.
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  int ncol = static_cast<int>(ncol_ptr - 1);
  std::vector<CSC_RowIterator> iterators;
  // The final size is known up front; reserve to avoid repeated reallocation.
  if (ncol > 0) {
    iterators.reserve(static_cast<size_t>(ncol));
  }
  for (int j = 0; j < ncol; ++j) {
    iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
  }
  // NOTE(review): the iterators are captured by reference and
  // CSC_RowIterator::Get mutates iterator state and only moves forward.
  // This looks safe only if Predict requests rows in increasing order from
  // a single thread -- confirm against Booster::Predict.
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
    [&iterators, ncol](int i) {
    std::vector<std::pair<int, double>> one_row;
    for (int j = 0; j < ncol; ++j) {
      auto val = iterators[j].Get(i);
      if (std::fabs(val) > kEpsilon) {
        one_row.emplace_back(j, val);
      }
    }
    return one_row;
  };
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1028
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1029
1030
1031
1032
1033
1034
1035
1036
1037
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
1038
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1039
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1040
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
1041
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len);
1042
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1043
}
1044

Guolin Ke's avatar
Guolin Ke committed
1045
int LGBM_BoosterSaveModel(BoosterHandle handle,
1046
1047
                          int num_iteration,
                          const char* filename) {
1048
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1049
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
1050
1051
1052
1053
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1054
// Serialises the model with Booster::SaveModelToString.
// *out_len always receives the required buffer size (text length + 1 for
// the terminating NUL). The text is copied into out_str only when
// buffer_len is at least that large; otherwise out_str is left untouched.
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                  int num_iteration,
                                  int buffer_len,
                                  int* out_len,
                                  char* out_str) {
  API_BEGIN();
  const std::string text =
    reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (buffer_len >= required) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1069
// Dumps the model with Booster::DumpModel.
// *out_len always receives the required buffer size (text length + 1 for
// the terminating NUL). The text is copied into out_str only when
// buffer_len is at least that large; otherwise out_str is left untouched.
int LGBM_BoosterDumpModel(BoosterHandle handle,
                          int num_iteration,
                          int buffer_len,
                          int* out_len,
                          char* out_str) {
  API_BEGIN();
  const std::string text =
    reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int required = static_cast<int>(text.size()) + 1;
  *out_len = required;
  if (buffer_len >= required) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}
1083

Guolin Ke's avatar
Guolin Ke committed
1084
// Reads Booster::GetLeafValue(tree_idx, leaf_idx) and stores it, converted
// to double, in *out_val.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* owner = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(owner->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1094
// Forwards to Booster::SetLeafValue(tree_idx, leaf_idx, val).
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1104
// ---- start of some helper functions
1105
1106
1107

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1108
  if (data_type == C_API_DTYPE_FLOAT32) {
1109
1110
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1111
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1112
        std::vector<double> ret(num_col);
1113
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1114
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1115
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1116
1117
1118
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1119
1120
1121
1122
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1123
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1124
        std::vector<double> ret(num_col);
1125
        for (int i = 0; i < num_col; ++i) {
1126
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1127
1128
1129
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1130
1131
1132
1133
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1134
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1135
1136
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1137
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1138
        std::vector<double> ret(num_col);
1139
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1140
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1141
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1142
1143
1144
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1145
1146
1147
1148
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1149
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1150
        std::vector<double> ret(num_col);
1151
        for (int i = 0; i < num_col; ++i) {
1152
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1153
1154
1155
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1156
1157
1158
1159
1160
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1161
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1162
1163
1164
1165
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1166
1167
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
1168
    return [inner_function](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1169
1170
1171
1172
1173
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1174
        }
Guolin Ke's avatar
Guolin Ke committed
1175
1176
1177
      }
      return ret;
    };
1178
  }
Guolin Ke's avatar
Guolin Ke committed
1179
  return nullptr;
1180
1181
1182
1183
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1184
  if (data_type == C_API_DTYPE_FLOAT32) {
1185
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1186
    if (indptr_type == C_API_DTYPE_INT32) {
1187
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1188
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1189
1190
1191
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1192
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1193
1194
1195
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1196
1197
1198
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1199
    } else if (indptr_type == C_API_DTYPE_INT64) {
1200
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1201
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1202
1203
1204
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1205
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1206
1207
1208
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1209
1210
1211
1212
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1213
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1214
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1215
    if (indptr_type == C_API_DTYPE_INT32) {
1216
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1217
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1218
1219
1220
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1221
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1222
1223
1224
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1225
1226
1227
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1228
    } else if (indptr_type == C_API_DTYPE_INT64) {
1229
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1230
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1231
1232
1233
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1234
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1235
1236
1237
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1238
1239
1240
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1241
1242
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1243
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1244
1245
}

Guolin Ke's avatar
Guolin Ke committed
1246
1247
1248
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1249
  if (data_type == C_API_DTYPE_FLOAT32) {
1250
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1251
    if (col_ptr_type == C_API_DTYPE_INT32) {
1252
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1253
1254
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1255
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1256
1257
1258
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1259
        }
Guolin Ke's avatar
Guolin Ke committed
1260
1261
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1262
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1263
        return std::make_pair(idx, val);
1264
      };
Guolin Ke's avatar
Guolin Ke committed
1265
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1266
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1267
1268
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1269
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1270
1271
1272
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1273
        }
Guolin Ke's avatar
Guolin Ke committed
1274
1275
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1276
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1277
        return std::make_pair(idx, val);
1278
      };
Guolin Ke's avatar
Guolin Ke committed
1279
    }
Guolin Ke's avatar
Guolin Ke committed
1280
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1281
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1282
    if (col_ptr_type == C_API_DTYPE_INT32) {
1283
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1284
1285
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1286
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1287
1288
1289
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1290
        }
Guolin Ke's avatar
Guolin Ke committed
1291
1292
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1293
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1294
        return std::make_pair(idx, val);
1295
      };
Guolin Ke's avatar
Guolin Ke committed
1296
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1297
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1298
1299
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1300
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1301
1302
1303
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1304
        }
Guolin Ke's avatar
Guolin Ke committed
1305
1306
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1307
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1308
        return std::make_pair(idx, val);
1309
      };
Guolin Ke's avatar
Guolin Ke committed
1310
1311
1312
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1313
1314
}

Guolin Ke's avatar
Guolin Ke committed
1315
// Binds this iterator to column col_idx of the given CSC matrix by storing
// the accessor produced by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data,
                                     data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row idx, or 0 when no entry is stored
// there. The iterator only advances: it consumes entries until the cached
// row index reaches or passes idx, or the column is exhausted.
double CSC_RowIterator::Get(int idx) {
  for (;;) {
    if (idx <= cur_idx_ || is_end_) {
      break;
    }
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // Column exhausted; keep the last cached entry.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (idx == cur_idx_) ? cur_val_ : 0.0f;
}

// Returns the next stored (row index, value) entry of the column and
// advances the position. Once the accessor reports a negative row index
// the iterator is marked finished and every later call yields (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}