c_api.cpp 47.2 KB
Newer Older
1
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
2
3
4

#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
8
9
10
11
12
13
14
15
16
17
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
18
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
19
#include <stdexcept>
wxchan's avatar
wxchan committed
20
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
21
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
22

Guolin Ke's avatar
Guolin Ke committed
23
#include "./application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
24
#include "./boosting/gbdt.h"
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
namespace LightGBM {

class Booster {
public:
Guolin Ke's avatar
Guolin Ke committed
30
31
  explicit Booster(const char* filename) {
    boosting_.reset(Boosting::CreateBoosting(filename));
Guolin Ke's avatar
Guolin Ke committed
32
33
  }

34
35
36
37
  Booster() {
    boosting_.reset(Boosting::CreateBoosting("gbdt", nullptr));
  }

Guolin Ke's avatar
Guolin Ke committed
38
  Booster(const Dataset* train_data,
39
          const char* parameters) {
wxchan's avatar
wxchan committed
40
41
    auto param = ConfigBase::Str2Map(parameters);
    config_.Set(param);
42
43
44
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
45
46
    // create boosting
    if (config_.io_config.input_model.size() > 0) {
Guolin Ke's avatar
Guolin Ke committed
47
      Log::Warning("continued train from model is not support for c_api, \
Guolin Ke's avatar
Guolin Ke committed
48
49
        please use continued train with input score");
    }
Guolin Ke's avatar
Guolin Ke committed
50

wxchan's avatar
wxchan committed
51
    boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
Guolin Ke's avatar
Guolin Ke committed
52

Guolin Ke's avatar
Guolin Ke committed
53
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
54
    boosting_->Init(&config_.boosting_config, nullptr, objective_fun_.get(),
55
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
56
57

    ResetTrainingData(train_data);
wxchan's avatar
wxchan committed
58
59
60
61
62
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
  }

  ~Booster() {
Guolin Ke's avatar
Guolin Ke committed
66

Guolin Ke's avatar
Guolin Ke committed
67
  }
68

wxchan's avatar
wxchan committed
69
70
71
  void ResetTrainingData(const Dataset* train_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    train_data_ = train_data;
Guolin Ke's avatar
Guolin Ke committed
72
73
    // create objective function
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
74
                                                                    config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(
        Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
    // reset the boosting
Guolin Ke's avatar
Guolin Ke committed
94
    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
95
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
96
97
98
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
99
    std::lock_guard<std::mutex> lock(mutex_);
wxchan's avatar
wxchan committed
100
101
102
103
104
105
106
    auto param = ConfigBase::Str2Map(parameters);
    if (param.count("num_class")) {
      Log::Fatal("cannot change num class during training");
    }
    if (param.count("boosting_type")) {
      Log::Fatal("cannot change boosting_type during training");
    }
Guolin Ke's avatar
Guolin Ke committed
107
108
109
    if (param.count("metric")) {
      Log::Fatal("cannot change metric during training");
    }
Guolin Ke's avatar
Guolin Ke committed
110
111

    config_.Set(param);
112
113
114
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118

    if (param.count("objective")) {
      // create objective function
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
119
                                                                      config_.objective_config));
Guolin Ke's avatar
Guolin Ke committed
120
121
122
123
124
125
126
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
wxchan's avatar
wxchan committed
127
    }
Guolin Ke's avatar
Guolin Ke committed
128
129

    boosting_->ResetTrainingData(&config_.boosting_config, train_data_,
130
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
Guolin Ke's avatar
Guolin Ke committed
131

wxchan's avatar
wxchan committed
132
133
134
135
136
137
138
139
140
141
142
143
144
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
    for (auto metric_type : config_.metric_types) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
145
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
146
  }
Guolin Ke's avatar
Guolin Ke committed
147

148
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
149
    std::lock_guard<std::mutex> lock(mutex_);
150
151
152
153
    return boosting_->TrainOneIter(nullptr, nullptr, false);
  }

  bool TrainOneIter(const float* gradients, const float* hessians) {
wxchan's avatar
wxchan committed
154
    std::lock_guard<std::mutex> lock(mutex_);
155
156
157
    return boosting_->TrainOneIter(gradients, hessians, false);
  }

wxchan's avatar
wxchan committed
158
159
160
161
162
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

Guolin Ke's avatar
Guolin Ke committed
163
164
165
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
166
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
167
168
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
169
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
170
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
171
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
172
      is_raw_score = true;
Guolin Ke's avatar
Guolin Ke committed
173
174
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
175
    }
Guolin Ke's avatar
Guolin Ke committed
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
    int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
    auto pred_fun = predictor.GetPredictFunction();
    auto pred_wrt_ptr = out_result;
    for (int i = 0; i < nrow; ++i) {
      auto one_row = get_row_fun(i);
      pred_fun(one_row, pred_wrt_ptr);
      pred_wrt_ptr += num_preb_in_one_row;
    }
    *out_len = nrow * num_preb_in_one_row;
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
               int data_has_header, const char* result_filename) {
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else {
      is_raw_score = false;
    }
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf);
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
203
204
  }

Guolin Ke's avatar
Guolin Ke committed
205
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
206
207
208
209
210
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

  void SaveModelToFile(int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
211
  }
212

213
214
215
216
217
218
219
220
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str);
  }

  std::string SaveModelToString(int num_iteration) {
    return boosting_->SaveModelToString(num_iteration);
  }

221
222
  std::string DumpModel(int num_iteration) {
    return boosting_->DumpModel(num_iteration);
wxchan's avatar
wxchan committed
223
  }
224

Guolin Ke's avatar
Guolin Ke committed
225
226
227
228
229
230
231
232
233
  double GetLeafValue(int tree_idx, int leaf_idx) const {
    return dynamic_cast<GBDT*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
    dynamic_cast<GBDT*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
  }

wxchan's avatar
wxchan committed
234
235
236
237
238
239
240
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
241

wxchan's avatar
wxchan committed
242
243
244
245
246
247
248
249
250
251
252
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::strcpy(out_strs[idx], name.c_str());
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
253
254
255
256
257
258
259
260
261
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::strcpy(out_strs[idx], name.c_str());
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
262
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
263

Guolin Ke's avatar
Guolin Ke committed
264
private:
265

wxchan's avatar
wxchan committed
266
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
267
  std::unique_ptr<Boosting> boosting_;
Guolin Ke's avatar
Guolin Ke committed
268
269
270
  /*! \brief All configs */
  OverallConfig config_;
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
271
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
272
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
273
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
274
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
275
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
276
277
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
278
279
280
};

}
Guolin Ke's avatar
Guolin Ke committed
281
282
283

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
284
285
286
287
288
289
290
291
292
293
// Adapters that turn the raw buffers passed through the C API into per-row accessors.
// Only declared here; the definitions live elsewhere (presumably later in this file).

// Dense matrix accessor: yields the values of one row (`row_major` selects the layout).
auto RowFunctionFromDenseMatric(const void* values, int rows, int cols, int value_type, int row_major)
    -> std::function<std::vector<double>(int row_idx)>;

// Dense matrix accessor yielding one row as (index, value) pairs — presumably column index.
auto RowPairFunctionFromDenseMatric(const void* values, int rows, int cols, int value_type, int row_major)
    -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

// CSR accessor: yields the stored (column index, value) pairs of one row.
auto RowFunctionFromCSR(const void* row_offsets, int row_offsets_type, const int32_t* col_indices,
                        const void* values, int value_type, int64_t num_offsets, int64_t num_values)
    -> std::function<std::vector<std::pair<int, double>>(int idx)>;
Guolin Ke's avatar
Guolin Ke committed
295
296
297
298
299

/*!
* \brief Iterates over one column (col_idx) of a CSC matrix, row by row.
*        The constructor and member functions are only declared here and are defined elsewhere.
*/
class CSC_RowIterator {
public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  // No user-declared destructor (Rule of Zero): the implicit one does the same as the old
  // empty body, and leaving it undeclared keeps the implicit move operations, so the
  // std::function member is moved instead of copied.

  /*! \brief Value at row idx; rows may only be requested in ascending order. */
  double Get(int idx);
  /*! \brief Next non-zero (row index, value) pair; a negative row index means no more data. */
  std::pair<int, double> NextNonZero();

private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
316
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
317
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
318
319
}

Guolin Ke's avatar
Guolin Ke committed
320
int LGBM_DatasetCreateFromFile(const char* filename,
321
322
323
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
324
  API_BEGIN();
wxchan's avatar
wxchan committed
325
326
327
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
328
  DatasetLoader loader(io_config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
329
  if (reference == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
330
    *out = loader.LoadFromFile(filename);
Guolin Ke's avatar
Guolin Ke committed
331
  } else {
Guolin Ke's avatar
Guolin Ke committed
332
    *out = loader.LoadFromFileAlignWithOtherDataset(filename,
333
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
334
  }
335
  API_END();
Guolin Ke's avatar
Guolin Ke committed
336
337
}

338

Guolin Ke's avatar
Guolin Ke committed
339
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
340
341
342
343
344
345
346
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
347
348
349
350
351
352
353
354
355
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
  DatasetLoader loader(io_config, nullptr, 1, nullptr);
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
356
357
}

358

Guolin Ke's avatar
Guolin Ke committed
359
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
360
361
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
362
363
364
365
366
367
368
369
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
370
int LGBM_DatasetPushRows(DatasetHandle dataset,
371
372
373
374
375
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
376
377
378
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
379
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
380
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
381
  for (int i = 0; i < nrow; ++i) {
382
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
383
384
385
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
386
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
387
  }
388
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
389
390
391
392
393
394
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
395
/*!
* \brief Push a CSR block (nindptr - 1 rows) starting at start_row.
*        Calls FinishLoad() once the block ends exactly at the Dataset's last row.
*        The unnamed int64_t parameter (column count) is not used here.
*/
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(dataset);
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // nindptr row offsets describe nindptr - 1 rows
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto values = row_reader(row);
    target->PushOneRow(thread_id, static_cast<data_size_t>(start_row + row), values);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool is_last_block = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (is_last_block) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
426
int LGBM_DatasetCreateFromMat(const void* data,
427
428
429
430
431
432
433
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
434
  API_BEGIN();
wxchan's avatar
wxchan committed
435
436
437
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
438
  std::unique_ptr<Dataset> ret;
439
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
440
441
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
442
    Random rand(io_config.data_random_seed);
443
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
444
    auto sample_indices = rand.Sample(nrow, sample_cnt);
445
    sample_cnt = static_cast<int>(sample_indices.size());
446
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
447
    std::vector<std::vector<int>> sample_idx(ncol);
Guolin Ke's avatar
Guolin Ke committed
448
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
449
      auto idx = sample_indices[i];
450
      auto row = get_row_fun(static_cast<int>(idx));
Guolin Ke's avatar
Guolin Ke committed
451
      for (size_t j = 0; j < row.size(); ++j) {
Guolin Ke's avatar
Guolin Ke committed
452
453
454
        if (std::fabs(row[j]) > kEpsilon) {
          sample_values[j].emplace_back(row[j]);
          sample_idx[j].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
455
        }
Guolin Ke's avatar
Guolin Ke committed
456
457
      }
    }
458
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
459
460
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
461
462
463
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
464
  } else {
465
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
466
    ret->CreateValid(
467
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
468
  }
469
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
470
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
471
  for (int i = 0; i < nrow; ++i) {
472
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
473
    const int tid = omp_get_thread_num();
474
    auto one_row = get_row_fun(i);
Guolin Ke's avatar
Guolin Ke committed
475
    ret->PushOneRow(tid, i, one_row);
476
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
477
  }
478
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
479
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
480
  *out = ret.release();
481
  API_END();
482
483
}

Guolin Ke's avatar
Guolin Ke committed
484
int LGBM_DatasetCreateFromCSR(const void* indptr,
485
486
487
488
489
490
491
492
493
494
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
495
  API_BEGIN();
wxchan's avatar
wxchan committed
496
497
498
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
499
  std::unique_ptr<Dataset> ret;
500
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
501
502
503
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
504
    Random rand(io_config.data_random_seed);
505
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
506
    auto sample_indices = rand.Sample(nrow, sample_cnt);
507
    sample_cnt = static_cast<int>(sample_indices.size());
508
    std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
509
    std::vector<std::vector<int>> sample_idx;
510
511
512
513
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
514
        if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
515
516
          sample_values.resize(inner_data.first + 1);
          sample_idx.resize(inner_data.first + 1);
517
        }
Guolin Ke's avatar
Guolin Ke committed
518
519
520
        if (std::fabs(inner_data.second) > kEpsilon) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
521
522
523
        }
      }
    }
524
    CHECK(num_col >= static_cast<int>(sample_values.size()));
525
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
526
527
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
528
529
530
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
531
  } else {
532
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
533
    ret->CreateValid(
534
      reinterpret_cast<const Dataset*>(reference));
535
  }
536
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
537
  #pragma omp parallel for schedule(static)
538
  for (int i = 0; i < nindptr - 1; ++i) {
539
    OMP_LOOP_EX_BEGIN();
540
541
542
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
543
    OMP_LOOP_EX_END();
544
  }
545
  OMP_THROW_EX();
546
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
547
  *out = ret.release();
548
  API_END();
549
550
}

Guolin Ke's avatar
Guolin Ke committed
551
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
552
553
554
555
556
557
558
559
560
561
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
562
  API_BEGIN();
wxchan's avatar
wxchan committed
563
564
565
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
Guolin Ke's avatar
Guolin Ke committed
566
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
567
568
569
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
wxchan's avatar
wxchan committed
570
    Random rand(io_config.data_random_seed);
571
    int sample_cnt = static_cast<int>(nrow < io_config.bin_construct_sample_cnt ? nrow : io_config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
572
    auto sample_indices = rand.Sample(nrow, sample_cnt);
573
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
574
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
575
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
576
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
577
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
578
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
579
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
580
581
582
583
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
        if (std::fabs(val) > kEpsilon) {
Guolin Ke's avatar
Guolin Ke committed
584
585
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
586
587
        }
      }
588
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
589
    }
590
    OMP_THROW_EX();
591
    DatasetLoader loader(io_config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
592
593
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
594
595
596
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
597
  } else {
598
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
599
    ret->CreateValid(
600
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
601
  }
602
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
603
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
604
  for (int i = 0; i < ncol_ptr - 1; ++i) {
605
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
606
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
607
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
608
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
609
610
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
611
612
613
614
615
616
617
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
618
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
619
    }
620
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
621
  }
622
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
623
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
624
  *out = ret.release();
625
  API_END();
Guolin Ke's avatar
Guolin Ke committed
626
627
}

Guolin Ke's avatar
Guolin Ke committed
628
int LGBM_DatasetGetSubset(
629
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
630
631
632
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
633
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
634
635
636
637
  API_BEGIN();
  auto param = ConfigBase::Str2Map(parameters);
  IOConfig io_config;
  io_config.Set(param);
638
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
639
  CHECK(num_used_row_indices > 0);
Guolin Ke's avatar
Guolin Ke committed
640
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
641
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
642
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
643
644
645
646
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
647
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
648
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
649
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
650
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
651
652
653
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
654
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
655
656
657
658
659
660
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
661
int LGBM_DatasetGetFeatureNames(
662
663
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
664
  int* num_feature_names) {
665
666
667
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
668
669
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
670
671
672
673
674
    std::strcpy(feature_names[i], inside_feature_name[i].c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
675
// Destroys the Dataset object referenced by `handle`.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
681
int LGBM_DatasetSaveBinary(DatasetHandle handle,
682
                           const char* filename) {
683
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
684
685
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
686
  API_END();
687
688
}

Guolin Ke's avatar
Guolin Ke committed
689
// Sets the metadata field `field_name` of the dataset behind `handle` from
// `num_element` values located at `field_data`.
//   type: C API dtype of `field_data`; C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32
//         and C_API_DTYPE_FLOAT64 are dispatched to SetFloatField,
//         SetIntField and SetDoubleField respectively.
// Throws std::runtime_error when `type` is none of those, or when the
// selected setter reports failure (e.g. unknown field / wrong type).
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  bool is_success = false;
  if (type == C_API_DTYPE_FLOAT32) {
    is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_INT32) {
    is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), static_cast<int32_t>(num_element));
  } else if (type == C_API_DTYPE_FLOAT64) {
    is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), static_cast<int32_t>(num_element));
  }
  // Error message typo fixed ("erorr" -> "error").
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
708
// Looks up metadata field `field_name` of the dataset behind `handle`.
// The float, int and double getters are tried in that order; the first one
// that succeeds fills `*out_len` / `*out_ptr`, and the matching C API dtype
// is written to `*out_type`. Throws std::runtime_error("Field not found")
// if none succeed. A null data pointer is reported with length 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
731
int LGBM_DatasetGetNumData(DatasetHandle handle,
732
                           int* out) {
733
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
734
735
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
736
  API_END();
737
738
}

Guolin Ke's avatar
Guolin Ke committed
739
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
740
                              int* out) {
741
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
742
743
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
744
  API_END();
Guolin Ke's avatar
Guolin Ke committed
745
}
746
747
748

// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
749
int LGBM_BoosterCreate(const DatasetHandle train_data,
750
751
                       const char* parameters,
                       BoosterHandle* out) {
752
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
753
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
754
755
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
756
  API_END();
757
758
}

Guolin Ke's avatar
Guolin Ke committed
759
int LGBM_BoosterCreateFromModelfile(
760
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
761
  int* out_num_iterations,
762
  BoosterHandle* out) {
763
  API_BEGIN();
wxchan's avatar
wxchan committed
764
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
765
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
766
  *out = ret.release();
767
  API_END();
768
769
}

Guolin Ke's avatar
Guolin Ke committed
770
int LGBM_BoosterLoadModelFromString(
771
772
773
774
775
776
777
778
779
780
781
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
  auto ret = std::unique_ptr<Booster>(new Booster());
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
782
// Destroys the Booster object referenced by `handle`.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
788
// Merges the booster behind `other_handle` into the booster behind `handle`
// via Booster::MergeFrom.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* source = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(source);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
797
int LGBM_BoosterAddValidData(BoosterHandle handle,
798
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
799
800
801
802
803
804
805
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
806
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
807
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
808
809
810
811
812
813
814
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
815
// Re-applies the configuration string `parameters` to the booster behind
// `handle` via Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
822
// Writes the number of classes of the underlying boosting model to
// `*out_len`.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = model->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
829
// Runs one training iteration using the booster's built-in objective.
// `*is_finished` becomes 1 when Booster::TrainOneIter returns true,
// otherwise 0.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
840
// Runs one training iteration with caller-supplied gradients `grad` and
// hessians `hess`. `*is_finished` becomes 1 when Booster::TrainOneIter
// returns true, otherwise 0.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  const bool finished = reinterpret_cast<Booster*>(handle)->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
854
// Undoes the most recent training iteration via Booster::RollbackOneIter.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
861
// Writes the current iteration of the underlying boosting model to
// `*out_iteration`.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = model->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
867

Guolin Ke's avatar
Guolin Ke committed
868
// Writes Booster::GetEvalCounts() to `*out_len`.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
875
// Fills the caller-provided `out_strs` buffers with evaluation names and
// writes the count returned by Booster::GetEvalNames to `*out_len`.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetEvalNames(out_strs);
  *out_len = written;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
882
// Fills the caller-provided `out_strs` buffers with the model's feature
// names and writes the count returned by Booster::GetFeatureNames to
// `*out_len`.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetFeatureNames(out_strs);
  *out_len = written;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
889
// Writes the model's feature count, computed as MaxFeatureIdx() + 1, to
// `*out_len`.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const int max_idx = model->MaxFeatureIdx();
  *out_len = max_idx + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
896
// Copies the evaluation results for data set `data_idx` into `out_results`
// (which must hold at least as many entries as there are results) and
// stores the number of results in `*out_len`.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto model = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto evals = model->GetEvalAt(data_idx);
  *out_len = static_cast<int>(evals.size());
  double* dst = out_results;
  for (const auto& metric_value : evals) {
    *dst++ = static_cast<double>(metric_value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
911
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
912
913
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
914
915
916
917
918
919
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
920
int LGBM_BoosterGetPredict(BoosterHandle handle,
921
922
923
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
924
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
925
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
926
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
927
  API_END();
Guolin Ke's avatar
Guolin Ke committed
928
929
}

Guolin Ke's avatar
Guolin Ke committed
930
// Predicts every row of the text file `data_filename` (optionally with a
// header line) and writes the results to `result_filename`.
// `predict_type` and `num_iteration` are forwarded to Booster::Predict.
int LGBM_BoosterPredictForFile(BoosterHandle handle,
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
                               const char* result_filename) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type,
                   data_filename, data_has_header, result_filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
942
// Computes how many output values a prediction over `num_row` rows will
// produce: num_row times the per-row count reported by
// NumPredictOneRow(num_iteration, is_leaf). `is_leaf` is true when
// predict_type == C_API_PREDICT_LEAF_INDEX. The result goes to `*out_len`.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const bool is_predict_leaf = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const int64_t per_row = static_cast<int64_t>(
    ref_booster->GetBoosting()->NumPredictOneRow(num_iteration, is_predict_leaf));
  // Widen BEFORE multiplying: num_row * per_row can exceed INT_MAX
  // (e.g. many rows x many trees for leaf-index output), and casting an
  // already-overflowed int product to int64_t does not recover the value.
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
954
// Predicts every row of a CSR matrix (row pointers `indptr` of dtype
// `indptr_type`, column `indices`, values `data` of dtype `data_type`).
// The row count is derived as nindptr - 1. Results are written to
// `out_result`, with their count in `*out_len`.
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t /* unused */,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  const int row_count = static_cast<int>(nindptr - 1);
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data,
                                       data_type, nindptr, nelem);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, row_count, row_reader, out_result, out_len);
  API_END();
}
974

Guolin Ke's avatar
Guolin Ke committed
975
// Predicts `num_row` rows of a CSC matrix (column pointers `col_ptr` of
// dtype `col_ptr_type`, row `indices`, values `data` of dtype `data_type`,
// ncol_ptr - 1 columns). Each row is rebuilt as sparse (column, value)
// pairs, keeping entries with |value| > kEpsilon, and passed to
// Booster::Predict. Results are written to `out_result`, with their count
// in `*out_len`.
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  int ncol = static_cast<int>(ncol_ptr - 1);
  std::vector<CSC_RowIterator> iterators;
  // One cursor per column; reserve up front to avoid repeated reallocation
  // (and element relocation) while the vector grows.
  if (ncol > 0) {
    iterators.reserve(static_cast<size_t>(ncol));
  }
  for (int j = 0; j < ncol; ++j) {
    iterators.emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
  }
  // NOTE(review): CSC_RowIterator::Get only advances forward and mutates the
  // shared iterators. This functor therefore assumes rows are requested in
  // non-decreasing order and never concurrently. Confirm that
  // Booster::Predict calls it that way.
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
    [&iterators, ncol](int i) {
    std::vector<std::pair<int, double>> one_row;
    for (int j = 0; j < ncol; ++j) {
      auto val = iterators[j].Get(i);
      if (std::fabs(val) > kEpsilon) {
        one_row.emplace_back(j, val);
      }
    }
    return one_row;
  };
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1010
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1011
1012
1013
1014
1015
1016
1017
1018
1019
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
                              int64_t* out_len,
                              double* out_result) {
1020
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1021
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1022
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
Guolin Ke's avatar
Guolin Ke committed
1023
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, out_result, out_len);
1024
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1025
}
1026

Guolin Ke's avatar
Guolin Ke committed
1027
int LGBM_BoosterSaveModel(BoosterHandle handle,
1028
1029
                          int num_iteration,
                          const char* filename) {
1030
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1031
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
wxchan's avatar
wxchan committed
1032
1033
1034
1035
  ref_booster->SaveModelToFile(num_iteration, filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1036
// Serialises the model to text. `*out_len` always receives the required
// buffer size (string length + 1 for the NUL). The text is copied into
// `out_str` only when that size fits within `buffer_len`; otherwise
// `out_str` is left untouched so the caller can retry with a larger buffer.
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
                                  int num_iteration,
                                  int buffer_len,
                                  int* out_len,
                                  char* out_str) {
  API_BEGIN();
  const std::string text = reinterpret_cast<Booster*>(handle)->SaveModelToString(num_iteration);
  const int needed = static_cast<int>(text.size()) + 1;
  *out_len = needed;
  const bool fits = !(needed > buffer_len);
  if (fits) {
    std::strcpy(out_str, text.c_str());
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1051
// Dumps the model through Booster::DumpModel. `*out_len` always receives
// the required buffer size (string length + 1 for the NUL). The text is
// copied into `out_str` only when that size fits within `buffer_len`.
int LGBM_BoosterDumpModel(BoosterHandle handle,
                          int num_iteration,
                          int buffer_len,
                          int* out_len,
                          char* out_str) {
  API_BEGIN();
  const std::string dump = reinterpret_cast<Booster*>(handle)->DumpModel(num_iteration);
  const int needed = static_cast<int>(dump.size()) + 1;
  *out_len = needed;
  const bool fits = !(needed > buffer_len);
  if (fits) {
    std::strcpy(out_str, dump.c_str());
  }
  API_END();
}
1065

Guolin Ke's avatar
Guolin Ke committed
1066
// Reads the output value of leaf `leaf_idx` in tree `tree_idx` into
// `*out_val`, widened to double.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  auto leaf_value = reinterpret_cast<Booster*>(handle)->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1076
// Overwrites the output value of leaf `leaf_idx` in tree `tree_idx` with
// `val`.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1086
// ---- start of some help functions
1087
1088
1089

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1090
  if (data_type == C_API_DTYPE_FLOAT32) {
1091
1092
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1093
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1094
        std::vector<double> ret(num_col);
1095
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1096
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1097
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1098
1099
1100
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1101
1102
1103
1104
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1105
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1106
        std::vector<double> ret(num_col);
1107
        for (int i = 0; i < num_col; ++i) {
1108
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1109
1110
1111
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1112
1113
1114
1115
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1116
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1117
1118
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1119
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1120
        std::vector<double> ret(num_col);
1121
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1122
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1123
          ret[i] = static_cast<double>(*(tmp_ptr + i));
Guolin Ke's avatar
Guolin Ke committed
1124
1125
1126
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1127
1128
1129
1130
        }
        return ret;
      };
    } else {
Guolin Ke's avatar
Guolin Ke committed
1131
      return [data_ptr, num_col, num_row](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1132
        std::vector<double> ret(num_col);
1133
        for (int i = 0; i < num_col; ++i) {
1134
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
Guolin Ke's avatar
Guolin Ke committed
1135
1136
1137
          if (std::isnan(ret[i])) {
            ret[i] = 0.0f;
          }
1138
1139
1140
1141
1142
        }
        return ret;
      };
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1143
  throw std::runtime_error("unknown data type in RowFunctionFromDenseMatric");
1144
1145
1146
1147
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1148
1149
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
1150
    return [inner_function](int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1151
1152
1153
1154
1155
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
        if (std::fabs(raw_values[i]) > 1e-15) {
          ret.emplace_back(i, raw_values[i]);
1156
        }
Guolin Ke's avatar
Guolin Ke committed
1157
1158
1159
      }
      return ret;
    };
1160
  }
Guolin Ke's avatar
Guolin Ke committed
1161
  return nullptr;
1162
1163
1164
1165
}

std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t nindptr, int64_t nelem) {
Guolin Ke's avatar
Guolin Ke committed
1166
  if (data_type == C_API_DTYPE_FLOAT32) {
1167
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1168
    if (indptr_type == C_API_DTYPE_INT32) {
1169
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1170
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1171
1172
1173
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1174
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1175
1176
1177
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1178
1179
1180
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1181
    } else if (indptr_type == C_API_DTYPE_INT64) {
1182
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1183
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1184
1185
1186
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1187
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1188
1189
1190
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1191
1192
1193
1194
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1195
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1196
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1197
    if (indptr_type == C_API_DTYPE_INT32) {
1198
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1199
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1200
1201
1202
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1203
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1204
1205
1206
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1207
1208
1209
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1210
    } else if (indptr_type == C_API_DTYPE_INT64) {
1211
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
Guolin Ke's avatar
Guolin Ke committed
1212
      return [ptr_indptr, indices, data_ptr, nindptr, nelem](int idx) {
1213
1214
1215
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1216
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1217
1218
1219
          if (!std::isnan(data_ptr[i])) {
            ret.emplace_back(indices[i], data_ptr[i]);
          }
1220
1221
1222
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1223
1224
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1225
  throw std::runtime_error("unknown data type in RowFunctionFromCSR");
1226
1227
}

Guolin Ke's avatar
Guolin Ke committed
1228
1229
1230
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1231
  if (data_type == C_API_DTYPE_FLOAT32) {
1232
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1233
    if (col_ptr_type == C_API_DTYPE_INT32) {
1234
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1235
1236
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1237
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1238
1239
1240
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1241
        }
Guolin Ke's avatar
Guolin Ke committed
1242
1243
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1244
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1245
        return std::make_pair(idx, val);
1246
      };
Guolin Ke's avatar
Guolin Ke committed
1247
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1248
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1249
1250
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1251
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1252
1253
1254
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1255
        }
Guolin Ke's avatar
Guolin Ke committed
1256
1257
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1258
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1259
        return std::make_pair(idx, val);
1260
      };
Guolin Ke's avatar
Guolin Ke committed
1261
    }
Guolin Ke's avatar
Guolin Ke committed
1262
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1263
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1264
    if (col_ptr_type == C_API_DTYPE_INT32) {
1265
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1266
1267
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1268
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1269
1270
1271
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1272
        }
Guolin Ke's avatar
Guolin Ke committed
1273
1274
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1275
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1276
        return std::make_pair(idx, val);
1277
      };
Guolin Ke's avatar
Guolin Ke committed
1278
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1279
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1280
1281
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
Guolin Ke's avatar
Guolin Ke committed
1282
      return [ptr_col_ptr, indices, data_ptr, ncol_ptr, nelem, start, end](int bias) {
Guolin Ke's avatar
Guolin Ke committed
1283
1284
1285
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1286
        }
Guolin Ke's avatar
Guolin Ke committed
1287
1288
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
Guolin Ke's avatar
Guolin Ke committed
1289
        if (std::isnan(val)) { val = 0.0f; }
Guolin Ke's avatar
Guolin Ke committed
1290
        return std::make_pair(idx, val);
1291
      };
Guolin Ke's avatar
Guolin Ke committed
1292
1293
1294
    }
  }
  throw std::runtime_error("unknown data type in CSC matrix");
1295
1296
}

Guolin Ke's avatar
Guolin Ke committed
1297
// Binds this iterator to column `col_idx` of the given CSC matrix. The
// per-entry accessor is obtained from IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

// Returns the value stored at row `idx` of this column, or 0 if none.
// The cursor only ever moves forward: it advances over stored entries until
// it reaches a row >= idx or the column is exhausted. A request for a row
// below the current position therefore yields 0 unless it equals that
// position.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0f;
}

// Returns the next stored (row index, value) entry of this column and
// advances past it. A negative row index signals exhaustion; after that,
// every call returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  const auto entry = iter_fun_(nonzero_idx_++);
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}