c_api.cpp 63.9 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
2

Guolin Ke's avatar
Guolin Ke committed
3
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
4
5
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/utils/log.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
9
10
11
12
13
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
cbecker's avatar
cbecker committed
14
#include <LightGBM/prediction_early_stop.h>
15
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
19
20

#include <cstdio>
#include <vector>
#include <string>
#include <cstring>
Guolin Ke's avatar
Guolin Ke committed
21
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
22
#include <stdexcept>
wxchan's avatar
wxchan committed
23
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
24
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
28
29
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*!
 * \brief Store the message of a caught std::exception as the last error.
 * \param ex Exception whose what() text is passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  LGBM_SetLastError(ex.what());
  return -1;
}
/*!
 * \brief Store a plain string message as the last error.
 * \param ex Message passed to LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  LGBM_SetLastError(ex.c_str());
  return -1;
}

// API_BEGIN()/API_END() wrap a function body in a try block. Anything thrown is
// turned into a -1 return via LGBM_APIHandleException; a body that finishes
// without throwing returns 0. Handlers catch by const reference because they
// only read the exception object.
#define API_BEGIN() try {
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

Guolin Ke's avatar
Guolin Ke committed
46
class Booster {
Nikita Titov's avatar
Nikita Titov committed
47
 public:
Guolin Ke's avatar
Guolin Ke committed
48
  explicit Booster(const char* filename) {
49
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
50
51
  }

Guolin Ke's avatar
Guolin Ke committed
52
  Booster(const Dataset* train_data,
53
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
54
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
55
    config_.Set(param);
56
57
58
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
59
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
60
    if (config_.input_model.size() > 0) {
61
62
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
63
    }
Guolin Ke's avatar
Guolin Ke committed
64

Guolin Ke's avatar
Guolin Ke committed
65
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
66

67
68
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
69
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
70
    if (config_.tree_learner == std::string("feature")) {
71
      Log::Fatal("Do not support feature parallel in c api");
72
    }
Guolin Ke's avatar
Guolin Ke committed
73
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
74
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
75
      config_.tree_learner = "serial";
76
    }
Guolin Ke's avatar
Guolin Ke committed
77
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
78
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
79
80
81
82
83
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
84
85
86
87
  }

  /*! \brief Nothing to release by hand; members clean up after themselves */
  ~Booster() = default;
88

89
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
90
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
91
92
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
97
98
99
100
101
102
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
103
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
104
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
105
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
111
112
113
114
115
116
117
118
119
120
121
  }

  /*!
   * \brief Switch the booster to a different training dataset.
   *
   * Does nothing when train_data is the dataset already in use. Otherwise the
   * objective and training metrics are rebuilt and the boosting object is
   * pointed at the new data.
   * \param train_data New training data; only the pointer is stored
   */
  void ResetTrainingData(const Dataset* train_data) {
    // Lock before reading train_data_, so the comparison and the update cannot
    // interleave with another locked method that writes train_data_.
    std::lock_guard<std::mutex> lock(mutex_);
    if (train_data == train_data_) {
      return;
    }
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_,
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
125
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
126
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
127
    if (param.count("num_class")) {
128
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
129
    }
Guolin Ke's avatar
Guolin Ke committed
130
131
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
132
    }
Guolin Ke's avatar
Guolin Ke committed
133
    if (param.count("metric")) {
134
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
135
    }
Guolin Ke's avatar
Guolin Ke committed
136
137

    config_.Set(param);
138
139
140
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
141
142
143

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
144
145
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
153
154
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
155
    }
Guolin Ke's avatar
Guolin Ke committed
156

Guolin Ke's avatar
Guolin Ke committed
157
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
158
159
160
161
162
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
163
164
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
165
166
167
168
169
170
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
171
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
172
  }
Guolin Ke's avatar
Guolin Ke committed
173

174
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
175
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
176
    return boosting_->TrainOneIter(nullptr, nullptr);
177
178
  }

Guolin Ke's avatar
Guolin Ke committed
179
180
181
182
183
184
185
186
187
188
189
  /*!
   * \brief Refit the trees using precomputed leaf predictions.
   * \param leaf_preds Flat array read as nrow rows of ncol values (row after row)
   * \param nrow Number of rows in leaf_preds
   * \param ncol Number of values per row in leaf_preds
   */
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      // widen before multiplying: i * ncol can exceed the range of int
      const size_t row_offset = static_cast<size_t>(i) * ncol;
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[row_offset + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

190
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
191
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
192
    return boosting_->TrainOneIter(gradients, hessians);
193
194
  }

wxchan's avatar
wxchan committed
195
196
197
198
199
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
  /*!
   * \brief Predict a single row, reusing a predictor that is built on first use.
   * \param num_iteration Forwarded to the Predictor and to NumPredictOneRow
   * \param predict_type One of the C_API_PREDICT_* constants
   * \param get_row_fun Row accessor; it is called with row index 0
   * \param config Source of the prediction early-stop settings
   * \param out_result Buffer the prediction function writes into
   * \param out_len Receives the number of values per row
   * \note The predictor is created only while single_row_predictor_ is empty.
   *       Later calls reuse it, and num_iteration, predict_type and config
   *       are then not consulted again.
   */
  void PredictSingleRow(int num_iteration, int predict_type,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> guard(mutex_);

    if (!single_row_predictor_) {
      // at most one of these flags is set; any other predict_type leaves all three false
      const bool want_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
      const bool want_raw_score = !want_leaf_index && (predict_type == C_API_PREDICT_RAW_SCORE);
      const bool want_contrib = !want_leaf_index && !want_raw_score &&
                                (predict_type == C_API_PREDICT_CONTRIB);

      // TODO: config could be optimized away... (maybe using lambda callback?)
      single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, want_raw_score, want_leaf_index, want_contrib,
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
      single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, want_leaf_index, want_contrib);
      single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
    }

    auto features = get_row_fun(0);
    single_row_predict_function_(features, out_result);

    *out_len = single_row_num_pred_in_one_row_;
  }


Guolin Ke's avatar
Guolin Ke committed
235
236
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
237
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
238
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
239
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
240
241
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
242
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
243
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
244
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
245
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
246
      is_raw_score = true;
247
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
248
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
249
250
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
251
    }
Guolin Ke's avatar
Guolin Ke committed
252

Guolin Ke's avatar
Guolin Ke committed
253
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
254
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
255
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
256
    auto pred_fun = predictor.GetPredictFunction();
257
258
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
259
    for (int i = 0; i < nrow; ++i) {
260
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
261
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
262
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
263
      pred_fun(one_row, pred_wrt_ptr);
264
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
265
    }
266
    OMP_THROW_EX();
267
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
268
269
270
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
271
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
272
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
273
274
275
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
276
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
277
278
279
280
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
281
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
282
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
283
284
285
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
286
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
287
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
288
289
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
290
291
  }

Guolin Ke's avatar
Guolin Ke committed
292
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
293
294
295
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

296
297
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
298
  }
299

300
  /*!
   * \brief Load a model from a null-terminated string.
   * \param model_str Model text; its length is measured with std::strlen
   */
  void LoadModelFromString(const char* model_str) {
    boosting_->LoadModelFromString(model_str, std::strlen(model_str));
  }

305
306
  /*!
   * \brief Forward to Boosting::SaveModelToString.
   * \return The string produced by that call
   */
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    std::string model_text = boosting_->SaveModelToString(start_iteration, num_iteration);
    return model_text;
  }

309
  std::string DumpModel(int start_iteration, int num_iteration) {
310
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
311
  }
312

313
314
315
316
  /*!
   * \brief Forward to Boosting::FeatureImportance.
   * \return The vector produced by that call
   */
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    std::vector<double> importances = boosting_->FeatureImportance(num_iteration, importance_type);
    return importances;
  }

Guolin Ke's avatar
Guolin Ke committed
317
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
318
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
323
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
324
325
  }

326
  void ShuffleModels(int start_iter, int end_iter) {
327
    std::lock_guard<std::mutex> lock(mutex_);
328
    boosting_->ShuffleModels(start_iter, end_iter);
329
330
  }

wxchan's avatar
wxchan committed
331
332
333
334
335
336
337
  /*!
   * \brief Count the evaluation names over all training metrics.
   * \return Sum of GetName().size() across train_metric_
   */
  int GetEvalCounts() const {
    int total = 0;
    for (size_t i = 0; i < train_metric_.size(); ++i) {
      total += static_cast<int>(train_metric_[i]->GetName().size());
    }
    return total;
  }
338

wxchan's avatar
wxchan committed
339
340
341
342
  /*!
   * \brief Copy the names of all training metrics into caller-provided buffers.
   * \param out_strs One destination buffer per name; each receives the name
   *        plus its terminating null. Buffer sizes are not checked here.
   * \return Number of names written
   */
  int GetEvalNames(char** out_strs) const {
    int written = 0;
    for (size_t m = 0; m < train_metric_.size(); ++m) {
      for (const auto& eval_name : train_metric_[m]->GetName()) {
        // size() + 1 so the terminating null is copied as well
        std::memcpy(out_strs[written++], eval_name.c_str(), eval_name.size() + 1);
      }
    }
    return written;
  }

wxchan's avatar
wxchan committed
350
351
352
  /*!
   * \brief Copy the feature names of the boosting object into caller-provided buffers.
   * \param out_strs One destination buffer per name; each receives the name
   *        plus its terminating null. Buffer sizes are not checked here.
   * \return Number of names written
   */
  int GetFeatureNames(char** out_strs) const {
    int written = 0;
    for (const auto& feature_name : boosting_->FeatureNames()) {
      // size() + 1 so the terminating null is copied as well
      std::memcpy(out_strs[written++], feature_name.c_str(), feature_name.size() + 1);
    }
    return written;
  }

wxchan's avatar
wxchan committed
359
  /*! \brief Read-only pointer to the boosting object held by this Booster */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
360

Nikita Titov's avatar
Nikita Titov committed
361
 private:
wxchan's avatar
wxchan committed
362
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
363
  std::unique_ptr<Boosting> boosting_;
364
365
366
367
  std::unique_ptr<Predictor> single_row_predictor_;
  PredictFunction single_row_predict_function_;
  int64_t single_row_num_pred_in_one_row_;

Guolin Ke's avatar
Guolin Ke committed
368
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
369
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
370
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
371
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
372
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
373
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
374
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
375
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
376
377
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
378
379
};

380
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
381
382
383

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
384
385
386
387
388
389
390
391
// some helper functions used to convert data

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

392
393
394
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

Guolin Ke's avatar
Guolin Ke committed
395
396
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
397
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
398
399
400

// Iterates over the rows of one column of a CSC matrix
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;
  // Value at row idx; rows must be requested in ascending order
  double Get(int idx);
  // Next non-zero (index, value) pair; a negative index means there is no more data
  std::pair<int, double> NextNonZero();

 private:
  // Iteration state; the member functions that use it are defined outside this declaration
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  // Accessor that returns an (index, value) pair for a given position
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
420
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
421
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
422
423
}

Guolin Ke's avatar
Guolin Ke committed
424
int LGBM_DatasetCreateFromFile(const char* filename,
425
426
427
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
428
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
429
430
  auto param = Config::Str2Map(parameters);
  Config config;
431
432
433
434
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
435
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
436
  if (reference == nullptr) {
437
438
439
440
441
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
442
  } else {
443
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
444
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
445
  }
446
  API_END();
Guolin Ke's avatar
Guolin Ke committed
447
448
}

449

Guolin Ke's avatar
Guolin Ke committed
450
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
451
452
453
454
455
456
457
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
458
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
459
460
  auto param = Config::Str2Map(parameters);
  Config config;
461
462
463
464
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
465
  DatasetLoader loader(config, nullptr, 1, nullptr);
466
467
468
469
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
470
471
}

472

Guolin Ke's avatar
Guolin Ke committed
473
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
474
475
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
476
477
478
479
480
481
482
483
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
484
int LGBM_DatasetPushRows(DatasetHandle dataset,
485
486
487
488
489
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
490
491
492
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
493
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
494
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
495
  for (int i = 0; i < nrow; ++i) {
496
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
497
498
499
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
500
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
501
  }
502
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
503
504
505
506
507
508
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
509
/*!
 * \brief Push the rows of a CSR matrix into a dataset, starting at row start_row.
 *
 * The number of rows is nindptr - 1. Once start_row plus that count equals the
 * dataset's num_data(), FinishLoad() is called. The unnamed int64_t parameter
 * is not used.
 * \return 0 on success, -1 if an exception was caught
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto target = reinterpret_cast<Dataset*>(dataset);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int row = 0; row < nrow; ++row) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto entries = read_row(row);
    target->PushOneRow(thread_id,
                       static_cast<data_size_t>(start_row + row), entries);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool reached_end = (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_end) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
540
int LGBM_DatasetCreateFromMat(const void* data,
541
542
543
544
545
546
547
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
569
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
570
571
  auto param = Config::Str2Map(parameters);
  Config config;
572
573
574
575
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
576
  std::unique_ptr<Dataset> ret;
577
578
579
580
581
582
583
584
585
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
586

Guolin Ke's avatar
Guolin Ke committed
587
588
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
589
    Random rand(config.data_random_seed);
590
591
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
592
    sample_cnt = static_cast<int>(sample_indices.size());
593
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
594
    std::vector<std::vector<int>> sample_idx(ncol);
595
596
597

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
598
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
599
      auto idx = sample_indices[i];
600
601
602
603
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
604

605
606
607
608
609
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
610
        }
Guolin Ke's avatar
Guolin Ke committed
611
612
      }
    }
Guolin Ke's avatar
Guolin Ke committed
613
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
614
615
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
616
617
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
618
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
619
  } else {
620
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
621
    ret->CreateValid(
622
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
623
  }
624
625
626
627
628
629
630
631
632
633
634
635
636
637
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
638
639
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
640
  *out = ret.release();
641
  API_END();
642
643
}

Guolin Ke's avatar
Guolin Ke committed
644
int LGBM_DatasetCreateFromCSR(const void* indptr,
645
646
647
648
649
650
651
652
653
654
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
655
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
656
657
  auto param = Config::Str2Map(parameters);
  Config config;
658
659
660
661
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
662
  std::unique_ptr<Dataset> ret;
663
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
664
665
666
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
667
668
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
669
    auto sample_indices = rand.Sample(nrow, sample_cnt);
670
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
671
672
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
673
674
675
676
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
677
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
678
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
679
680
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
681
682
683
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
684
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
685
686
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
687
688
689
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
690
  } else {
691
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
692
    ret->CreateValid(
693
      reinterpret_cast<const Dataset*>(reference));
694
  }
695
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
696
  #pragma omp parallel for schedule(static)
697
  for (int i = 0; i < nindptr - 1; ++i) {
698
    OMP_LOOP_EX_BEGIN();
699
700
701
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
702
    OMP_LOOP_EX_END();
703
  }
704
  OMP_THROW_EX();
705
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
706
  *out = ret.release();
707
  API_END();
708
709
}

710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
                              int num_rows,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
  API_BEGIN();

  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
760

761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
            const int tid = omp_get_thread_num();
            get_row_fun(i, threadBuffer);

            ret->PushOneRow(tid, i, threadBuffer);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
780
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
781
782
783
784
785
786
787
788
789
790
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
791
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
792
793
  auto param = Config::Str2Map(parameters);
  Config config;
794
795
796
797
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
798
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
799
800
801
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
802
803
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
804
    auto sample_indices = rand.Sample(nrow, sample_cnt);
805
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
806
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
807
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
808
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
809
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
810
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
811
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
812
813
814
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
815
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
816
817
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
818
819
        }
      }
820
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
821
    }
822
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
823
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
824
825
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
826
827
828
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
829
  } else {
830
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
831
    ret->CreateValid(
832
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
833
  }
834
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
835
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
836
  for (int i = 0; i < ncol_ptr - 1; ++i) {
837
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
838
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
839
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
840
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
841
842
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
843
844
845
846
847
848
849
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
850
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
851
    }
852
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
853
  }
854
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
855
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
856
  *out = ret.release();
857
  API_END();
Guolin Ke's avatar
Guolin Ke committed
858
859
}

Guolin Ke's avatar
Guolin Ke committed
860
int LGBM_DatasetGetSubset(
861
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
862
863
864
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
865
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
866
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
867
868
  auto param = Config::Str2Map(parameters);
  Config config;
869
870
871
872
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
873
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
874
  CHECK(num_used_row_indices > 0);
875
876
877
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
878
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
879
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
880
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
881
882
883
884
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
885
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
886
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
887
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
888
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
889
890
891
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
892
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
893
894
895
896
897
898
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
899
int LGBM_DatasetGetFeatureNames(
900
901
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
902
  int* num_feature_names) {
903
904
905
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
906
907
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
908
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
909
910
911
912
  }
  API_END();
}

913
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
914
// Destroys the Dataset object that `handle` points to.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
920
int LGBM_DatasetSaveBinary(DatasetHandle handle,
921
                           const char* filename) {
922
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
923
924
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
925
  API_END();
926
927
}

928
929
930
931
932
933
934
935
// Forwards to Dataset::DumpTextFile(filename) on the dataset behind `handle`.
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
936
// Stores num_element values from field_data under field_name on the dataset.
//
// `type` selects how field_data is interpreted: C_API_DTYPE_FLOAT32, C_API_DTYPE_INT32 or
// C_API_DTYPE_FLOAT64. Any other type, or a setter that reports failure, raises
// std::runtime_error, which API_END turns into an error return.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* values = reinterpret_cast<const float*>(field_data);
    stored = target->SetFloatField(field_name, values, count);
  } else if (type == C_API_DTYPE_INT32) {
    const int* values = reinterpret_cast<const int*>(field_data);
    stored = target->SetIntField(field_name, values, count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    const double* values = reinterpret_cast<const double*>(field_data);
    stored = target->SetDoubleField(field_name, values, count);
  }
  if (!stored) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
955
// Looks up field_name on the dataset and exposes its storage through out_ptr/out_len.
//
// The getters are tried in the order float, int, double, int8; the first one that
// succeeds determines *out_type. If none succeeds std::runtime_error("Field not found")
// is thrown. When the resulting pointer is null, *out_len is forced to 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  bool found = true;
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (source->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    found = false;
  }
  if (!found) {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

981
982
983
984
985
986
987
// Forwards `parameters` to Dataset::ResetConfig on the dataset behind `handle`.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
988
int LGBM_DatasetGetNumData(DatasetHandle handle,
989
                           int* out) {
990
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
991
992
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
993
  API_END();
994
995
}

Guolin Ke's avatar
Guolin Ke committed
996
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
997
                              int* out) {
998
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
999
1000
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1001
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1002
}
1003

1004
1005
1006
1007
1008
1009
1010
1011
1012
// Calls addFeaturesFrom on the `target` dataset, passing the `source` dataset.
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* donor = reinterpret_cast<Dataset*>(source);
  Dataset* receiver = reinterpret_cast<Dataset*>(target);
  receiver->addFeaturesFrom(donor);
  API_END();
}

1013
1014
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1015
int LGBM_BoosterCreate(const DatasetHandle train_data,
1016
1017
                       const char* parameters,
                       BoosterHandle* out) {
1018
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1019
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1020
1021
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1022
  API_END();
1023
1024
}

Guolin Ke's avatar
Guolin Ke committed
1025
int LGBM_BoosterCreateFromModelfile(
1026
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1027
  int* out_num_iterations,
1028
  BoosterHandle* out) {
1029
  API_BEGIN();
wxchan's avatar
wxchan committed
1030
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1031
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1032
  *out = ret.release();
1033
  API_END();
1034
1035
}

Guolin Ke's avatar
Guolin Ke committed
1036
int LGBM_BoosterLoadModelFromString(
1037
1038
1039
1040
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1041
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1042
1043
1044
1045
1046
1047
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1048
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1049
// Destroys the Booster object that `handle` points to.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1055
// Forwards start_iter and end_iter to Booster::ShuffleModels on the booster behind `handle`.
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1062
// Calls MergeFrom on the booster behind `handle`, passing the booster behind `other_handle`.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* donor = reinterpret_cast<Booster*>(other_handle);
  Booster* receiver = reinterpret_cast<Booster*>(handle);
  receiver->MergeFrom(donor);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1071
int LGBM_BoosterAddValidData(BoosterHandle handle,
1072
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1073
1074
1075
1076
1077
1078
1079
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1080
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1081
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1082
1083
1084
1085
1086
1087
1088
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1089
// Forwards `parameters` to Booster::ResetConfig on the booster behind `handle`.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1096
// Writes GetBoosting()->NumberOfClasses() of the booster behind `handle` to *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1103
1104
1105
1106
1107
1108
1109
// Forwards leaf_preds, nrow and ncol to Booster::Refit on the booster behind `handle`.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1110
// Runs Booster::TrainOneIter() once and writes 1 to *is_finished when it returns true,
// 0 otherwise.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1121
// Runs Booster::TrainOneIter(grad, hess) once with caller-supplied arrays and writes 1 to
// *is_finished when it returns true, 0 otherwise.
// When SCORE_T_USE_DOUBLE is defined the call is replaced by Log::Fatal.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1139
// Calls Booster::RollbackOneIter() on the booster behind `handle`.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1146
// Writes GetBoosting()->GetCurrentIteration() of the booster behind `handle` to *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1152

1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
// Writes GetBoosting()->NumModelPerIteration() of the booster behind `handle`
// to *out_tree_per_iteration.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

// Writes GetBoosting()->NumberOfTotalModel() of the booster behind `handle` to *out_models.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1167
// Writes Booster::GetEvalCounts() of the booster behind `handle` to *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1174
// Passes out_strs to Booster::GetEvalNames and stores its return value in *out_len.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1181
// Passes out_strs to Booster::GetFeatureNames and stores its return value in *out_len.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *out_len = booster->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1188
// Writes GetBoosting()->MaxFeatureIdx() + 1 of the booster behind `handle` to *out_len.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto boosting = booster->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1195
// Fetches the evaluation results for data_idx via GetBoosting()->GetEvalAt, stores
// their count in *out_len and copies each value (as double) into out_results.
// out_results must have room for every returned value.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto eval_values = booster->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(eval_values.size());
  size_t pos = 0;
  while (pos < eval_values.size()) {
    out_results[pos] = static_cast<double>(eval_values[pos]);
    ++pos;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1210
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1211
1212
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1213
1214
1215
1216
1217
1218
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1219
int LGBM_BoosterGetPredict(BoosterHandle handle,
1220
1221
1222
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1223
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1224
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1225
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1226
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1227
1228
}

Guolin Ke's avatar
Guolin Ke committed
1229
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1230
1231
1232
1233
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1234
                               const char* parameter,
1235
                               const char* result_filename) {
1236
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1237
1238
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1239
1240
1241
1242
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1243
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1244
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1245
                       config, result_filename);
1246
  API_END();
1247
1248
}

Guolin Ke's avatar
Guolin Ke committed
1249
// Computes the number of output values for a prediction over num_row rows:
// num_row (widened to int64_t) times GetBoosting()->NumPredictOneRow(...), where the two
// flags state whether predict_type is C_API_PREDICT_LEAF_INDEX or C_API_PREDICT_CONTRIB.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool is_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool is_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(num_iteration, is_leaf_index, is_contrib);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1261
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1272
                              const char* parameter,
1273
1274
                              int64_t* out_len,
                              double* out_result) {
1275
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1276
1277
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1278
1279
1280
1281
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1282
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1283
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1284
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1285
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1286
                       config, out_result, out_len);
1287
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1288
}
1289

1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
// Single-row variant of the CSR prediction entry point: builds the same CSR
// row accessor but hands it to Booster::PredictSingleRow instead of Predict.
// The unnamed int64_t argument is accepted but never read.
// Exceptions are caught by API_BEGIN/API_END and turned into an error return.
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
                              const char* parameter,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  const bool has_thread_override = config.num_threads > 0;
  if (has_thread_override) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader,
                            config, out_result, out_len);
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1319
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1330
                              const char* parameter,
1331
1332
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1333
1334
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1335
1336
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1347
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1348
1349
1350
1351
1352
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1353
1354
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1355
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1356
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1357
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1358
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1359
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1360
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1361
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1362
1363
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1364
1365
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1366
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1367
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1368
1369
1370
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1371
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1372
1373
1374
1375
1376
1377
1378
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1379
                              const char* parameter,
1380
1381
                              int64_t* out_len,
                              double* out_result) {
1382
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1383
1384
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1385
1386
1387
1388
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1389
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1390
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1391
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1392
                       config, out_result, out_len);
1393
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1394
}
1395

1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
// Single-row variant of the dense prediction entry point: the input is
// treated as a matrix with exactly one row (num_row is passed as 1) and the
// accessor is handed to Booster::PredictSingleRow.
// Exceptions are caught by API_BEGIN/API_END and turned into an error return.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                              const void* data,
                              int data_type,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
                              const char* parameter,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  const bool has_thread_override = config.num_threads > 0;
  if (has_thread_override) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader,
                            config, out_result, out_len);
  API_END();
}


1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
// Runs batch prediction on dense data given as an array of row pointers
// (data[i] is the i-th row). The accessor is built by
// RowPairFunctionFromDenseRows and forwarded to Booster::Predict.
// Exceptions are caught by API_BEGIN/API_END and turned into an error return.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  const bool has_thread_override = config.num_threads > 0;
  if (has_thread_override) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, nrow, row_reader,
                   config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1445
int LGBM_BoosterSaveModel(BoosterHandle handle,
1446
                          int start_iteration,
1447
1448
                          int num_iteration,
                          const char* filename) {
1449
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1450
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1451
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1452
1453
1454
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1455
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1456
                                  int start_iteration,
1457
                                  int num_iteration,
1458
                                  int64_t buffer_len,
1459
                                  int64_t* out_len,
1460
                                  char* out_str) {
1461
1462
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1463
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1464
  *out_len = static_cast<int64_t>(model.size()) + 1;
1465
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1466
    std::memcpy(out_str, model.c_str(), *out_len);
1467
1468
1469
1470
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1471
int LGBM_BoosterDumpModel(BoosterHandle handle,
1472
                          int start_iteration,
1473
                          int num_iteration,
1474
1475
                          int64_t buffer_len,
                          int64_t* out_len,
1476
                          char* out_str) {
wxchan's avatar
wxchan committed
1477
1478
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1479
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1480
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1481
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1482
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1483
  }
1484
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1485
}
1486

Guolin Ke's avatar
Guolin Ke committed
1487
// Reads one leaf value through Booster::GetLeafValue and stores it, converted
// to double, in *out_val. Exceptions are caught by API_BEGIN/API_END and
// turned into an error return.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1497
// Writes one leaf value through Booster::SetLeafValue. Exceptions are caught
// by API_BEGIN/API_END and turned into an error return.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
// Copies the vector returned by Booster::FeatureImportance into out_results,
// one double per element. The caller's buffer must hold at least as many
// doubles as that vector has entries; no size is checked here.
// Exceptions are caught by API_BEGIN/API_END and turned into an error return.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double importance : importances) {
    *dst = importance;
    ++dst;
  }
  API_END();
}

1520
1521
1522
1523
1524
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1525
  Config config;
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Calls Network::Dispose(). Exceptions are caught by API_BEGIN/API_END and
// turned into an error return.
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1542
1543
1544
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1545
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1546
  if (num_machines > 1) {
1547
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1548
1549
1550
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1551

Guolin Ke's avatar
Guolin Ke committed
1552
// ---- start of helper functions
1553
1554
1555

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1556
  if (data_type == C_API_DTYPE_FLOAT32) {
1557
1558
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1559
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1560
        std::vector<double> ret(num_col);
1561
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1562
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1563
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1564
1565
1566
1567
        }
        return ret;
      };
    } else {
1568
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1569
        std::vector<double> ret(num_col);
1570
        for (int i = 0; i < num_col; ++i) {
1571
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1572
1573
1574
1575
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1576
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1577
1578
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1579
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1580
        std::vector<double> ret(num_col);
1581
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1582
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1583
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1584
1585
1586
1587
        }
        return ret;
      };
    } else {
1588
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1589
        std::vector<double> ret(num_col);
1590
        for (int i = 0; i < num_col; ++i) {
1591
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1592
1593
1594
1595
1596
        }
        return ret;
      };
    }
  }
1597
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1598
1599
1600
1601
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1602
1603
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1604
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1605
1606
1607
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1608
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1609
          ret.emplace_back(i, raw_values[i]);
1610
        }
Guolin Ke's avatar
Guolin Ke committed
1611
1612
1613
      }
      return ret;
    };
1614
  }
Guolin Ke's avatar
Guolin Ke committed
1615
  return nullptr;
1616
1617
}

1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
// data is an array of pointers to individual rows.
// The returned accessor reads data[row_idx] as a 1 x num_col row-major matrix
// and returns its (column index, value) pairs, keeping only entries that are
// NaN or whose absolute value exceeds kZeroThreshold. The dense reader is
// created on every call, so an unsupported data_type is reported when the
// accessor is invoked rather than when it is built.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto single_row_reader = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto dense_values = single_row_reader(0);
    std::vector<std::pair<int, double>> sparse_row;
    const int value_count = static_cast<int>(dense_values.size());
    for (int col = 0; col < value_count; ++col) {
      const bool keep = std::fabs(dense_values[col]) > kZeroThreshold || std::isnan(dense_values[col]);
      if (keep) {
        sparse_row.emplace_back(col, dense_values[col]);
      }
    }
    return sparse_row;
  };
}

1634
std::function<std::vector<std::pair<int, double>>(int idx)>
1635
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1636
  if (data_type == C_API_DTYPE_FLOAT32) {
1637
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1638
    if (indptr_type == C_API_DTYPE_INT32) {
1639
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1640
      return [=] (int idx) {
1641
1642
1643
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1644
1645
1646
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1647
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1648
          ret.emplace_back(indices[i], data_ptr[i]);
1649
1650
1651
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1652
    } else if (indptr_type == C_API_DTYPE_INT64) {
1653
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1654
      return [=] (int idx) {
1655
1656
1657
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1658
1659
1660
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1661
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1662
          ret.emplace_back(indices[i], data_ptr[i]);
1663
1664
1665
1666
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1667
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1668
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1669
    if (indptr_type == C_API_DTYPE_INT32) {
1670
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1671
      return [=] (int idx) {
1672
1673
1674
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1675
1676
1677
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1678
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1679
          ret.emplace_back(indices[i], data_ptr[i]);
1680
1681
1682
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1683
    } else if (indptr_type == C_API_DTYPE_INT64) {
1684
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1685
      return [=] (int idx) {
1686
1687
1688
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1689
1690
1691
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1692
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1693
          ret.emplace_back(indices[i], data_ptr[i]);
1694
1695
1696
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1697
1698
    }
  }
1699
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1700
1701
}

Guolin Ke's avatar
Guolin Ke committed
1702
std::function<std::pair<int, double>(int idx)>
1703
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1704
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1705
  if (data_type == C_API_DTYPE_FLOAT32) {
1706
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1707
    if (col_ptr_type == C_API_DTYPE_INT32) {
1708
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1709
1710
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1711
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1712
1713
1714
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1715
        }
Guolin Ke's avatar
Guolin Ke committed
1716
1717
1718
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1719
      };
Guolin Ke's avatar
Guolin Ke committed
1720
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1721
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1722
1723
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1724
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1725
1726
1727
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1728
        }
Guolin Ke's avatar
Guolin Ke committed
1729
1730
1731
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1732
      };
Guolin Ke's avatar
Guolin Ke committed
1733
    }
Guolin Ke's avatar
Guolin Ke committed
1734
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1735
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1736
    if (col_ptr_type == C_API_DTYPE_INT32) {
1737
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1738
1739
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1740
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1741
1742
1743
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1744
        }
Guolin Ke's avatar
Guolin Ke committed
1745
1746
1747
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1748
      };
Guolin Ke's avatar
Guolin Ke committed
1749
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1750
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1751
1752
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1753
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1754
1755
1756
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1757
        }
Guolin Ke's avatar
Guolin Ke committed
1758
1759
1760
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1761
      };
Guolin Ke's avatar
Guolin Ke committed
1762
1763
    }
  }
1764
  throw std::runtime_error("Unknown data type in CSC matrix");
1765
1766
}

Guolin Ke's avatar
Guolin Ke committed
1767
// Binds the iterator to column col_idx of the given CSC matrix by storing the
// column reader produced by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the column's value at row idx, or 0 when the cached entry is not at
// that row. The iterator only moves forward: stored entries are consumed
// until the cached row index reaches idx or the column is exhausted.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto next_entry = iter_fun_(nonzero_idx_);
    if (next_entry.first >= 0) {
      cur_idx_ = next_entry.first;
      cur_val_ = next_entry.second;
      ++nonzero_idx_;
    } else {
      // A negative row index marks the end of the column.
      is_end_ = true;
    }
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0;
}

// Returns the next stored (row index, value) pair of the column and advances
// the position. Once the column is exhausted, every call returns (-1, 0.0).
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // A negative row index marks the end of the column.
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}