#include <LightGBM/utils/openmp_wrapper.h>

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
#include <LightGBM/c_api.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <mutex>
#include <functional>
#include <utility>

#include "./application/predictor.hpp"
#include "./boosting/gbdt.h"

namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
164
165
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
166
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
167
168
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
169
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
170
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
171
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
172
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
173
174
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
175
    }
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
    int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
    auto pred_fun = predictor.GetPredictFunction();
    auto pred_wrt_ptr = out_result;
    for (int i = 0; i < nrow; ++i) {
      auto one_row = get_row_fun(i);
      pred_fun(one_row, pred_wrt_ptr);
      pred_wrt_ptr += num_preb_in_one_row;
    }
    *out_len = nrow * num_preb_in_one_row;
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
               int data_has_header, const char* result_filename) {
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else {
      is_raw_score = false;
    }
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
203
204
  }

Guolin Ke's avatar
Guolin Ke committed
205
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
206
207
208
209
210
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
211
  }
212

213
214
215
216
217
218
219
220
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

221
222
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
223
  }
224

Guolin Ke's avatar
Guolin Ke committed
225
226
227
228
229
230
231
232
233
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
234
235
236
237
238
239
240
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
241

wxchan's avatar
wxchan committed
242
243
244
245
246
247
248
249
250
251
252
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
253
254
255
256
257
258
259
260
261
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
262
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
263

Guolin Ke's avatar
Guolin Ke committed
264
private:
265

wxchan's avatar
wxchan committed
266
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
267
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
268
269
270
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
271
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
272
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
273
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
274
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
275
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
276
277
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
278
279
280
};

}
Guolin Ke's avatar
Guolin Ke committed
281
282
283

using namespace LightGBM;

// Helper functions used to convert raw C API buffers into per-row accessors.
// Only declarations appear here.

/*! \brief Row accessor for a dense matrix: returns row `row_idx` as a vector of values. */
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

/*! \brief Row accessor for a dense matrix returning (column index, value) pairs. */
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

/*! \brief Row accessor for a CSR matrix returning the (column index, value) pairs of row `idx`. */
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);

// Iterates over the rows of a single column of a CSC matrix.
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Return the value at row `idx`; rows must be requested in ascending order.
  double Get(int idx);
  // Return the next non-zero (row, value) pair; a negative row index means
  // there is no more data in this column.
  std::pair<int, double> NextNonZero();
private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
317
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
318
319
}

Guolin Ke's avatar
Guolin Ke committed
320
int LGBM_DatasetCreateFromFile(const char* filename,
321
322
323
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
324
  API_BEGIN();
wxchan's avatar
wxchan committed
325
326
327
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
328
  DatasetLoader loader(io_config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
329
  if (reference == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
330
    *out = loader.LoadFromFile(filename);
Guolin Ke's avatar
Guolin Ke committed
331
  } else {
Guolin Ke's avatar
Guolin Ke committed
332
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
333
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
334
  }
335
  API_END();
Guolin Ke's avatar
Guolin Ke committed
336
337
}

338

Guolin Ke's avatar
Guolin Ke committed
339
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
340
341
342
343
344
345
346
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
347
348
349
350
351
352
353
354
355
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
356
357
}

358

Guolin Ke's avatar
Guolin Ke committed
359
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
360
361
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
362
363
364
365
366
367
368
369
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
370
int LGBM_DatasetPushRows(DatasetHandle dataset,
371
372
373
374
375
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
376
377
378
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
379
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
380
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
381
  for (int i = 0; i < nrow; ++i) {
382
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
383
384
385
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
386
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
387
  }
388
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
389
390
391
392
393
394
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
395
/*!
 * \brief Push the rows of a CSR matrix into \p dataset, starting at \p start_row.
 *
 * The number of rows is nindptr - 1. When this batch ends exactly at the
 * dataset's last row, FinishLoad() is called.
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t /* num_col: unused */,
                              int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid,
                          static_cast<data_size_t>(start_row + i), one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

int LGBM_DatasetCreateFromMat(const void* data,
427
428
429
430
431
432
433
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
434
  API_BEGIN();
wxchan's avatar
wxchan committed
435
436
437
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
438
  std::unique_ptr<Dataset> ret;
439
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
440
441
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
442
    Random rand(io_config.data_random_seed);
443
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
444
    auto sample_indices = rand.Sample(nrow, sample_cnt);
445
    sample_cnt = static_cast<int>(sample_indices.size());
446
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
447
    std::vector<std::vector<int>> sample_idx(ncol);
Guolin Ke's avatar
Guolin Ke committed
448
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
449
      auto idx = sample_indices[i];
450
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
451
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
452
453
454
        if (std::fabs(row[j]) > kEpsilon) {
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
455
        }
Guolin Ke's avatar
Guolin Ke committed
456
457
      }
    }
458
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
459
460
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
461
462
463
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
464
  } else {
465
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
466
    ret->CreateValid(
467
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
468
  }
469
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
470
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
471
  for (int i = 0; i < nrow; ++i) {
472
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
473
    const int tid = omp_get_thread_num();
474
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
475
    ret->PushOneRow(tid, i, one_row);
476
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
477
  }
478
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
479
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
480
  *out = ret.release();
481
  API_END();
482
483
}

Guolin Ke's avatar
Guolin Ke committed
484
int LGBM_DatasetCreateFromCSR(const void* indptr,
485
486
487
488
489
490
491
492
493
494
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
495
  API_BEGIN();
wxchan's avatar
wxchan committed
496
497
498
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
499
  std::unique_ptr<Dataset> ret;
500
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
501
502
503
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
504
    Random rand(io_config.data_random_seed);
505
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
506
    auto sample_indices = rand.Sample(nrow, sample_cnt);
507
    sample_cnt = static_cast<int>(sample_indices.size());
508
    std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
509
    std::vector<std::vector<int>> sample_idx;
510
511
512
513
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
514
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
515
516
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
517
        }
Guolin Ke's avatar
Guolin Ke committed
518
519
520
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
521
522
523
        }
      }
    }
524
    CHECK(num_col >= static_cast<int>(sample_values.size()));
525
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
526
527
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
528
529
530
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
531
  } else {
532
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
533
    ret->CreateValid(
534
      reinterpret_cast<const Dataset*>(reference));
535
  }
536
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
537
  #pragma omp parallel for schedule(static)
538
  for (int i = 0; i < nindptr - 1; ++i) {
539
    OMP_LOOP_EX_BEGIN();
540
541
542
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
543
    OMP_LOOP_EX_END();
544
  }
545
  OMP_THROW_EX();
546
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
547
  *out = ret.release();
548
  API_END();
549
550
}

Guolin Ke's avatar
Guolin Ke committed
551
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
552
553
554
555
556
557
558
559
560
561
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
562
  API_BEGIN();
wxchan's avatar
wxchan committed
563
564
565
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
566
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
567
568
569
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
570
    Random rand(io_config.data_random_seed);
571
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
572
    auto sample_indices = rand.Sample(nrow, sample_cnt);
573
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
574
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
575
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
576
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
577
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
578
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
579
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
580
581
582
583
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
Guolin Ke's avatar
Guolin Ke committed
584
585
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
586
587
        }
      }
588
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
589
    }
590
    OMP_THROW_EX();
591
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
592
593
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
594
595
596
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
597
  } else {
598
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
599
    ret->CreateValid(
600
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
601
  }
602
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
603
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
604
  for (int i = 0; i < ncol_ptr - 1; ++i) {
605
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
606
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
607
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
608
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
609
610
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
611
612
613
614
615
616
617
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
618
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
619
    }
620
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
621
  }
622
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
623
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
624
  *out = ret.release();
625
  API_END();
Guolin Ke's avatar
Guolin Ke committed
626
627
}

Guolin Ke's avatar
Guolin Ke committed
628
int LGBM_DatasetGetSubset(
629
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
630
631
632
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
633
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
634
635
636
637
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
638
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
639
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
640
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
641
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
642
643
644
645
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
646
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
647
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
648
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
649
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
650
651
652
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
653
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
654
655
656
657
658
659
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
660
int LGBM_DatasetGetFeatureNames(
661
662
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
663
  int* num_feature_names) {
664
665
666
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
667
668
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
669
670
671
672
673
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
674
// Deletes the Dataset behind `handle`; the handle is dangling afterwards.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
680
int LGBM_DatasetSaveBinary(DatasetHandle handle,
681
                           const char* filename) {
682
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
683
684
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
685
  API_END();
686
687
}

Guolin Ke's avatar
Guolin Ke committed
688
// Sets the Dataset field named `field_name` from `num_element` values stored in
// `field_data`, interpreted according to `type`: C_API_DTYPE_FLOAT32,
// C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64.
// Throws std::runtime_error inside the API_BEGIN()/API_END() region when
// `type` is none of those, or when the matching Dataset setter returns false.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), len);
  }
  // Reached with is_success == false for an unknown dtype as well as for a
  // field the setter rejects, so the message covers both cases.
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
707
// Looks up the Dataset field `field_name` and exposes it without copying:
// `out_ptr` receives the data pointer, `out_len` its length and `out_type` the
// dtype (C_API_DTYPE_FLOAT32 / _INT32 / _FLOAT64) of the getter that matched.
// Throws std::runtime_error("Field not found") when no getter recognizes the name.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  // Probe float, then int, then double; the first getter that succeeds wins.
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  // A null data pointer is reported as an empty field.
  if (*out_ptr == nullptr) { *out_len = 0; }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
730
int LGBM_DatasetGetNumData(DatasetHandle handle,
731
                           int* out) {
732
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
733
734
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
735
  API_END();
736
737
}

Guolin Ke's avatar
Guolin Ke committed
738
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
739
                              int* out) {
740
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
741
742
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
743
  API_END();
Guolin Ke's avatar
Guolin Ke committed
744
}
745
746
747

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
748
int LGBM_BoosterCreate(const DatasetHandle train_data,
749
750
                       const char* parameters,
                       BoosterHandle* out) {
751
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
752
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
753
754
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
755
  API_END();
756
757
}

Guolin Ke's avatar
Guolin Ke committed
758
int LGBM_BoosterCreateFromModelfile(
759
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
760
  int* out_num_iterations,
761
  BoosterHandle* out) {
762
  API_BEGIN();
wxchan's avatar
wxchan committed
763
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
764
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
765
  *out = ret.release();
766
  API_END();
767
768
}

Guolin Ke's avatar
Guolin Ke committed
769
int LGBM_BoosterLoadModelFromString(
770
771
772
773
774
775
776
777
778
779
780
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  auto ret = std::unique_ptr<Booster>(new Booster());
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
781
// Deletes the Booster behind `handle`; the handle is dangling afterwards.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
787
// Merges the Booster behind `other_handle` into the one behind `handle`
// via Booster::MergeFrom.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  auto target = reinterpret_cast<Booster*>(handle);
  auto source = reinterpret_cast<Booster*>(other_handle);
  target->MergeFrom(source);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
796
int LGBM_BoosterAddValidData(BoosterHandle handle,
797
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
798
799
800
801
802
803
804
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
805
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
806
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
807
808
809
810
811
812
813
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
814
// Re-applies the `parameters` string to the Booster via Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
821
// Stores the boosting model's NumberOfClasses() in `out_len`.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
828
// Runs one training iteration. `is_finished` is set to 1 when
// Booster::TrainOneIter() returns true, and to 0 otherwise.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool done = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = done ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
839
// Runs one training iteration using caller-supplied gradients `grad` and
// hessians `hess`. `is_finished` is set to 1 when Booster::TrainOneIter(grad, hess)
// returns true, and to 0 otherwise.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  const bool done = reinterpret_cast<Booster*>(handle)->TrainOneIter(grad, hess);
  *is_finished = done ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
853
// Undoes the most recent training iteration via Booster::RollbackOneIter.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
860
// Stores the boosting model's GetCurrentIteration() in `out_iteration`.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
866

Guolin Ke's avatar
Guolin Ke committed
867
// Stores Booster::GetEvalCounts() in `out_len`.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
874
// Fills `out_strs` through Booster::GetEvalNames and stores the count it
// returns in `out_len`.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
881
// Fills `out_strs` through Booster::GetFeatureNames and stores the count it
// returns in `out_len`.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
888
// Stores the model's feature count in `out_len`: MaxFeatureIdx() is the
// largest feature index, so the count is one more than it.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
895
// Copies the evaluation results for data set `data_idx` (from
// Boosting::GetEvalAt) into `out_results` as doubles and stores their count in
// `out_len`. The capacity of `out_results` is not checked here.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst++ = static_cast<double>(score);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
910
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
911
912
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
913
914
915
916
917
918
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
919
int LGBM_BoosterGetPredict(BoosterHandle handle,
920
921
922
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
923
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
924
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
925
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
926
  API_END();
Guolin Ke's avatar
Guolin Ke committed
927
928
}

Guolin Ke's avatar
Guolin Ke committed
929
// Predicts every row of `data_filename` and writes the results to
// `result_filename`, delegating entirely to the Booster's file-based Predict.
int LGBM_BoosterPredictForFile(BoosterHandle handle,
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
                               const char* result_filename) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, data_filename, data_has_header, result_filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
941
// Computes how many values a prediction over `num_row` rows will produce:
// the per-row count from Boosting::NumPredictOneRow times `num_row`.
// NumPredictOneRow is queried in leaf-index mode when
// predict_type == C_API_PREDICT_LEAF_INDEX.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const bool is_predict_leaf = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const int64_t num_pred_per_row = static_cast<int64_t>(
    ref_booster->GetBoosting()->NumPredictOneRow(num_iteration, is_predict_leaf));
  // Widen before multiplying. The product can exceed INT_MAX (e.g. leaf-index
  // output over many rows), and signed int overflow is undefined behavior.
  *out_len = static_cast<int64_t>(num_row) * num_pred_per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
953
// Predicts for a CSR matrix. `indptr` has `nindptr` entries, so there are
// nindptr - 1 rows. Rows are decoded lazily by RowFunctionFromCSR.
// The unnamed int64_t parameter (column count) is not used here.
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  auto row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  const int num_rows = static_cast<int>(nindptr - 1);
  booster->Predict(num_iteration, predict_type, num_rows, row_fun, out_result, out_len);
  API_END();
}
973

Guolin Ke's avatar
Guolin Ke committed
974
// Predicts for a CSC matrix with ncol_ptr - 1 columns and `num_row` rows.
// Each row is rebuilt on demand by asking every column iterator for that row
// and keeping the entries whose magnitude exceeds kEpsilon.
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  const int num_cols = static_cast<int>(ncol_ptr - 1);
  std::vector<CSC_RowIterator> col_iters;
  for (int col = 0; col < num_cols; ++col) {
    col_iters.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col);
  }
  // NOTE(review): every call shares col_iters, and CSC_RowIterator::Get keeps a
  // per-column cursor that only moves forward. This is correct only if
  // Booster::Predict (not visible here) calls get_row_fun from one thread with
  // non-decreasing row indices — confirm.
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
    [&col_iters, num_cols](int row_idx) {
    std::vector<std::pair<int, double>> row;
    for (int col = 0; col < num_cols; ++col) {
      const double val = col_iters[col].Get(row_idx);
      if (std::fabs(val) > kEpsilon) {
        row.emplace_back(col, val);
      }
    }
    return row;
  };
  booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1009
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1010
1011
1012
1013
1014
1015
1016
1017
1018
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
1019
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1020
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1021
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
1022
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len);
1023
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1024
}
1025

Guolin Ke's avatar
Guolin Ke committed
1026
int LGBM_BoosterSaveModel(BoosterHandle handle,
1027
1028
                          int num_iteration,
                          const char* filename) {
1029
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1030
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
1031
1032
1033
1034
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1035
// Serializes the model to text. `out_len` always receives the size needed
// including the trailing NUL. The text is copied into `out_str` only when that
// size fits in `buffer_len`, so callers can retry with a larger buffer.
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                  int num_iteration,
                                  int buffer_len,
                                  int* out_len,
                                  char* out_str) {
  API_BEGIN();
  const std::string model = reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  const int required_len = static_cast<int>(model.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, model.c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1050
// Dumps the model via Booster::DumpModel. `out_len` always receives the size
// needed including the trailing NUL. The text is copied into `out_str` only
// when that size fits in `buffer_len`.
int LGBM_BoosterDumpModel(BoosterHandle handle,
                          int num_iteration,
                          int buffer_len,
                          int* out_len,
                          char* out_str) {
  API_BEGIN();
  const std::string dump = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int required_len = static_cast<int>(dump.size()) + 1;
  *out_len = required_len;
  if (required_len <= buffer_len) {
    std::strcpy(out_str, dump.c_str());
  }
  API_END();
}
1064

Guolin Ke's avatar
Guolin Ke committed
1065
// Reads the output value of leaf `leaf_idx` in tree `tree_idx` into `out_val`.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  auto booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1075
// Overwrites the output value of leaf `leaf_idx` in tree `tree_idx` with `val`.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1085
// ---- start of some help functions
1086
1087
1088

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1089
  if (data_type == C_API_DTYPE_FLOAT32) {
1090
1091
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1092
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1093
        std::vector<double> ret(num_col);
1094
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1095
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1096
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1097
1098
1099
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1100
1101
1102
1103
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1104
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1105
        std::vector<double> ret(num_col);
1106
        for (int i = 0; i < num_col; ++i) {
1107
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1108
1109
1110
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1111
1112
1113
1114
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1115
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1116
1117
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1118
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1119
        std::vector<double> ret(num_col);
1120
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1121
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1122
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1123
1124
1125
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1126
1127
1128
1129
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1130
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1131
        std::vector<double> ret(num_col);
1132
        for (int i = 0; i < num_col; ++i) {
1133
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1134
1135
1136
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1137
1138
1139
1140
1141
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1142
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1143
1144
1145
1146
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1147
1148
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
1149
    return [inner_function](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1150
1151
1152
1153
1154
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1155
        }
Guolin Ke's avatar
Guolin Ke committed
1156
1157
1158
      }
      return ret;
    };
1159
  }
Guolin Ke's avatar
Guolin Ke committed
1160
  return nullptr;
1161
1162
1163
1164
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1165
  if (data_type == C_API_DTYPE_FLOAT32) {
1166
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1167
    if (indptr_type == C_API_DTYPE_INT32) {
1168
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1169
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1170
1171
1172
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1173
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1174
1175
1176
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1177
1178
1179
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1180
    } else if (indptr_type == C_API_DTYPE_INT64) {
1181
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1182
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1183
1184
1185
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1186
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1187
1188
1189
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1190
1191
1192
1193
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1194
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1195
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1196
    if (indptr_type == C_API_DTYPE_INT32) {
1197
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1198
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1199
1200
1201
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1202
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1203
1204
1205
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1206
1207
1208
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1209
    } else if (indptr_type == C_API_DTYPE_INT64) {
1210
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1211
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1212
1213
1214
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1215
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1216
1217
1218
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1219
1220
1221
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1222
1223
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1224
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1225
1226
}

Guolin Ke's avatar
Guolin Ke committed
1227
1228
1229
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1230
  if (data_type == C_API_DTYPE_FLOAT32) {
1231
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1232
    if (col_ptr_type == C_API_DTYPE_INT32) {
1233
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1234
1235
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1236
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1237
1238
1239
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1240
        }
Guolin Ke's avatar
Guolin Ke committed
1241
1242
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1243
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1244
        return std::make_pair(idx, val);
1245
      };
Guolin Ke's avatar
Guolin Ke committed
1246
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1247
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1248
1249
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1250
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1251
1252
1253
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1254
        }
Guolin Ke's avatar
Guolin Ke committed
1255
1256
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1257
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1258
        return std::make_pair(idx, val);
1259
      };
Guolin Ke's avatar
Guolin Ke committed
1260
    }
Guolin Ke's avatar
Guolin Ke committed
1261
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1262
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1263
    if (col_ptr_type == C_API_DTYPE_INT32) {
1264
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1265
1266
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1267
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1268
1269
1270
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1271
        }
Guolin Ke's avatar
Guolin Ke committed
1272
1273
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1274
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1275
        return std::make_pair(idx, val);
1276
      };
Guolin Ke's avatar
Guolin Ke committed
1277
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1278
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1279
1280
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1281
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1282
1283
1284
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1285
        }
Guolin Ke's avatar
Guolin Ke committed
1286
1287
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1288
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1289
        return std::make_pair(idx, val);
1290
      };
Guolin Ke's avatar
Guolin Ke committed
1291
1292
1293
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1294
1295
}

Guolin Ke's avatar
Guolin Ke committed
1296
// Binds this iterator to column `col_idx` of a CSC matrix by storing the
// per-column cursor produced by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the value stored at row `idx` of this column, or 0 if that row has
// no stored entry. The cursor only moves forward: it is advanced through
// stored entries while the current row is below `idx` and the column is not
// exhausted. Querying a row smaller than the current cursor position therefore
// yields 0 unless it equals the current row.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (row index, value) pair of this column and advances
// the cursor. Once the column is exhausted it returns (-1, 0.0) and latches
// is_end_, so later calls return (-1, 0.0) without touching the cursor.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}