c_api.cpp 63.9 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
2

Guolin Ke's avatar
Guolin Ke committed
3
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
4
5
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/random.h>
Guolin Ke's avatar
Guolin Ke committed
6
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/utils/log.h>
Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/dataset_loader.h>
Guolin Ke's avatar
Guolin Ke committed
9
10
11
12
13
#include <LightGBM/dataset.h>
#include <LightGBM/boosting.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/metric.h>
#include <LightGBM/config.h>
cbecker's avatar
cbecker committed
14
#include <LightGBM/prediction_early_stop.h>
15
#include <LightGBM/network.h>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
19
20

#include <cmath>
#include <cstdio>
#include <cstring>
#include <limits>
#include <string>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
21
#include <memory>
Guolin Ke's avatar
Guolin Ke committed
22
#include <stdexcept>
wxchan's avatar
wxchan committed
23
#include <mutex>
Guolin Ke's avatar
Guolin Ke committed
24
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
25

Guolin Ke's avatar
Guolin Ke committed
26
27
#include "./application/predictor.hpp"

Guolin Ke's avatar
Guolin Ke committed
28
29
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*!
 * \brief Record a caught std::exception as the last C API error message.
 * \param ex The caught exception; its what() text is stored via LGBM_SetLastError.
 * \return -1, the failure code returned by C API functions.
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  LGBM_SetLastError(ex.what());
  return -1;
}
/*!
 * \brief Record a caught std::string (or a plain message) as the last C API error message.
 * \param ex The message text, stored via LGBM_SetLastError.
 * \return -1, the failure code returned by C API functions.
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  LGBM_SetLastError(ex.c_str());
  return -1;
}

// API_BEGIN()/API_END() bracket the body of every C API entry point: any exception escaping
// the body is converted into a stored error message plus a -1 return code, and normal
// completion returns 0. Exceptions are caught by const reference (no copies, no slicing).
#define API_BEGIN() try {
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

Guolin Ke's avatar
Guolin Ke committed
46
class Booster {
Nikita Titov's avatar
Nikita Titov committed
47
 public:
Guolin Ke's avatar
Guolin Ke committed
48
  explicit Booster(const char* filename) {
49
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
50
51
  }

Guolin Ke's avatar
Guolin Ke committed
52
  Booster(const Dataset* train_data,
53
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
54
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
55
    config_.Set(param);
56
57
58
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
59
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
60
    if (config_.input_model.size() > 0) {
61
62
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
63
    }
Guolin Ke's avatar
Guolin Ke committed
64

Guolin Ke's avatar
Guolin Ke committed
65
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
66

67
68
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
69
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
70
    if (config_.tree_learner == std::string("feature")) {
71
      Log::Fatal("Do not support feature parallel in c api");
72
    }
Guolin Ke's avatar
Guolin Ke committed
73
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
74
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
75
      config_.tree_learner = "serial";
76
    }
Guolin Ke's avatar
Guolin Ke committed
77
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
78
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
79
80
81
82
83
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
84
85
86
87
  }

  ~Booster() {
  }
88

89
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
90
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
91
92
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
97
98
99
100
101
102
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
103
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
104
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
105
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
111
112
113
114
115
116
117
118
119
120
121
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
122
123
124
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
125
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
126
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
127
    if (param.count("num_class")) {
128
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
129
    }
Guolin Ke's avatar
Guolin Ke committed
130
131
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
132
    }
Guolin Ke's avatar
Guolin Ke committed
133
    if (param.count("metric")) {
134
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
135
    }
Guolin Ke's avatar
Guolin Ke committed
136
137

    config_.Set(param);
138
139
140
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
141
142
143

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
144
145
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
146
147
148
149
150
151
152
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
153
154
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
155
    }
Guolin Ke's avatar
Guolin Ke committed
156

Guolin Ke's avatar
Guolin Ke committed
157
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
158
159
160
161
162
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
163
164
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
165
166
167
168
169
170
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
171
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
172
  }
Guolin Ke's avatar
Guolin Ke committed
173

174
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
175
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
176
    return boosting_->TrainOneIter(nullptr, nullptr);
177
178
  }

Guolin Ke's avatar
Guolin Ke committed
179
180
181
182
183
184
185
186
187
188
189
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

190
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
191
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
192
    return boosting_->TrainOneIter(gradients, hessians);
193
194
  }

wxchan's avatar
wxchan committed
195
196
197
198
199
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
  void PredictSingleRow(int num_iteration, int predict_type,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> lock(mutex_);

    if (single_row_predictor_.get() == nullptr) {
      bool is_predict_leaf = false;
      bool is_raw_score = false;
      bool predict_contrib = false;
      if (predict_type == C_API_PREDICT_LEAF_INDEX) {
        is_predict_leaf = true;
      } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
        is_raw_score = true;
      } else if (predict_type == C_API_PREDICT_CONTRIB) {
        predict_contrib = true;
      } else {
        is_raw_score = false;
      }

      // TODO: config could be optimized away... (maybe using lambda callback?)
      single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
      single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
      single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
    }

    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
    single_row_predict_function_(one_row, pred_wrt_ptr);

    *out_len = single_row_num_pred_in_one_row_;
  }


Guolin Ke's avatar
Guolin Ke committed
235
236
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
237
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
238
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
239
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
240
241
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
242
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
243
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
244
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
245
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
246
      is_raw_score = true;
247
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
248
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
249
250
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
251
    }
Guolin Ke's avatar
Guolin Ke committed
252

Guolin Ke's avatar
Guolin Ke committed
253
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
254
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
255
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
256
    auto pred_fun = predictor.GetPredictFunction();
257
258
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
259
    for (int i = 0; i < nrow; ++i) {
260
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
261
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
262
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
263
      pred_fun(one_row, pred_wrt_ptr);
264
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
265
    }
266
    OMP_THROW_EX();
267
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
268
269
270
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
271
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
272
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
273
274
275
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
276
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
277
278
279
280
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
281
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
282
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
283
284
285
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
286
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
287
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
288
289
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
290
291
  }

Guolin Ke's avatar
Guolin Ke committed
292
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
293
294
295
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

296
297
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
298
  }
299

300
  void LoadModelFromString(const char* model_str) {
301
302
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
303
304
  }

305
306
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
307
308
  }

309
  std::string DumpModel(int start_iteration, int num_iteration) {
310
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
311
  }
312

313
314
315
316
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
317
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
318
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
323
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
324
325
  }

326
  void ShuffleModels(int start_iter, int end_iter) {
327
    std::lock_guard<std::mutex> lock(mutex_);
328
    boosting_->ShuffleModels(start_iter, end_iter);
329
330
  }

wxchan's avatar
wxchan committed
331
332
333
334
335
336
337
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
338

wxchan's avatar
wxchan committed
339
340
341
342
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
343
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
344
345
346
347
348
349
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
350
351
352
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
353
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
354
355
356
357
358
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
359
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
360

Nikita Titov's avatar
Nikita Titov committed
361
 private:
wxchan's avatar
wxchan committed
362
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
363
  std::unique_ptr<Boosting> boosting_;
364
365
366
367
  std::unique_ptr<Predictor> single_row_predictor_;
  PredictFunction single_row_predict_function_;
  int64_t single_row_num_pred_in_one_row_;

Guolin Ke's avatar
Guolin Ke committed
368
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
369
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
370
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
371
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
372
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
373
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
374
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
375
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
376
377
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
378
379
};

380
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
381
382
383

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
384
385
386
387
388
389
390
391
// Forward declarations of helpers that adapt raw C-API buffers into per-row accessors.
// Each returned callable yields the contents of one row when invoked with its index.
// (The "Matric" spelling is part of the existing identifiers and is kept as-is.)

// Dense matrix -> dense row (one double per column).
auto RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major)
    -> std::function<std::vector<double>(int row_idx)>;

// Dense matrix -> sparse row of (column index, value) pairs.
auto RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major)
    -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

// Array of per-row dense buffers -> sparse row of (column index, value) pairs.
auto RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type)
    -> std::function<std::vector<std::pair<int, double>>(int row_idx)>;

// CSR matrix (indptr / indices / data) -> sparse row of (column index, value) pairs.
auto RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                        const void* data, int data_type, int64_t nindptr, int64_t nelem)
    -> std::function<std::vector<std::pair<int, double>>(int idx)>;
Guolin Ke's avatar
Guolin Ke committed
398
399
400

// Row iterator over one column of a CSC matrix.
// Member functions are defined out of line; only the interface and state live here.
class CSC_RowIterator {
 public:
  /*!
   * \brief Iterate the rows of column col_idx of a CSC matrix.
   * \param col_ptr Column pointer array, element type given by col_ptr_type.
   * \param indices Row indices of the non-zero entries.
   * \param data Non-zero values, element type given by data_type.
   * \param ncol_ptr Length of col_ptr.
   * \param nelem Number of non-zero entries.
   * \param col_idx Column to iterate.
   */
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  // No user-declared destructor (Rule of Zero): keeps the implicit move operations available.

  // Return the value at row idx; per the original contract, rows must be requested in
  // ascending order.
  double Get(int idx);
  // Return the next non-zero (row, value) pair; a negative row index means no more data.
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
420
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
421
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
422
423
}

Guolin Ke's avatar
Guolin Ke committed
424
int LGBM_DatasetCreateFromFile(const char* filename,
425
426
427
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
428
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
429
430
  auto param = Config::Str2Map(parameters);
  Config config;
431
432
433
434
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
435
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
436
  if (reference == nullptr) {
437
438
439
440
441
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
442
  } else {
443
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
444
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
445
  }
446
  API_END();
Guolin Ke's avatar
Guolin Ke committed
447
448
}

449

Guolin Ke's avatar
Guolin Ke committed
450
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
451
452
453
454
455
456
457
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
458
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
459
460
  auto param = Config::Str2Map(parameters);
  Config config;
461
462
463
464
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
465
  DatasetLoader loader(config, nullptr, 1, nullptr);
466
467
468
469
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
470
471
}

472

Guolin Ke's avatar
Guolin Ke committed
473
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
474
475
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
476
477
478
479
480
481
482
483
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
484
int LGBM_DatasetPushRows(DatasetHandle dataset,
485
486
487
488
489
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
490
491
492
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
493
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
494
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
495
  for (int i = 0; i < nrow; ++i) {
496
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
497
498
499
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
500
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
501
  }
502
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
503
504
505
506
507
508
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
509
/*!
 * \brief Push the rows of a CSR block into dataset starting at start_row.
 *        When the block ends exactly at the dataset's row count, loading is finished.
 * \return 0 on success, -1 on failure.
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto* target = reinterpret_cast<Dataset*>(dataset);
  auto row_at = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  // An index pointer with nindptr entries delimits nindptr - 1 rows.
  const int32_t row_count = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int r = 0; r < row_count; ++r) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto sparse_row = row_at(r);
    const auto global_row = static_cast<data_size_t>(start_row + r);
    target->PushOneRow(thread_id, global_row, sparse_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const int64_t end_row = start_row + row_count;
  if (end_row == static_cast<int64_t>(target->num_data())) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
540
int LGBM_DatasetCreateFromMat(const void* data,
541
542
543
544
545
546
547
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
569
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
570
571
  auto param = Config::Str2Map(parameters);
  Config config;
572
573
574
575
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
576
  std::unique_ptr<Dataset> ret;
577
578
579
580
581
582
583
584
585
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
586

Guolin Ke's avatar
Guolin Ke committed
587
588
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
589
    Random rand(config.data_random_seed);
590
591
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
592
    sample_cnt = static_cast<int>(sample_indices.size());
593
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
594
    std::vector<std::vector<int>> sample_idx(ncol);
595
596
597

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
598
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
599
      auto idx = sample_indices[i];
600
601
602
603
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
604

605
606
607
608
609
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
610
        }
Guolin Ke's avatar
Guolin Ke committed
611
612
      }
    }
Guolin Ke's avatar
Guolin Ke committed
613
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
614
615
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
616
617
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
618
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
619
  } else {
620
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
621
    ret->CreateValid(
622
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
623
  }
624
625
626
627
628
629
630
631
632
633
634
635
636
637
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
638
639
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
640
  *out = ret.release();
641
  API_END();
642
643
}

Guolin Ke's avatar
Guolin Ke committed
644
int LGBM_DatasetCreateFromCSR(const void* indptr,
645
646
647
648
649
650
651
652
653
654
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
655
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
656
657
  auto param = Config::Str2Map(parameters);
  Config config;
658
659
660
661
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
662
  std::unique_ptr<Dataset> ret;
663
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
664
665
666
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
667
668
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
669
    auto sample_indices = rand.Sample(nrow, sample_cnt);
670
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
671
672
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
673
674
675
676
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
677
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
678
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
679
680
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
681
682
683
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
684
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
685
686
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
687
688
689
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
690
  } else {
691
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
692
    ret->CreateValid(
693
      reinterpret_cast<const Dataset*>(reference));
694
  }
695
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
696
  #pragma omp parallel for schedule(static)
697
  for (int i = 0; i < nindptr - 1; ++i) {
698
    OMP_LOOP_EX_BEGIN();
699
700
701
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
702
    OMP_LOOP_EX_END();
703
  }
704
  OMP_THROW_EX();
705
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
706
  *out = ret.release();
707
  API_END();
708
709
}

710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
// Builds a Dataset from rows produced by a caller-supplied callback.
// `get_row_funptr` is cast to
// std::function<void(int idx, std::vector<std::pair<int, double>>&)>*; the
// callback fills the given buffer with (column, value) pairs for row `idx`.
// When `reference` is null it is first called for each sampled row (to build
// bin mappers), then for every row from inside an OpenMP parallel loop, so it
// may be invoked concurrently from several threads.
// NOTE(review): the buffer is not cleared between calls, so the callback
// presumably replaces its contents each time — confirm against c_api.h.
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
                              int num_rows,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
  API_BEGIN();
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory across sampled rows
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (const std::pair<int, double>& inner_data : buffer) {
        // The column index is produced by user code and used to index
        // sample_values/sample_idx; reject negatives as well as overflow.
        CHECK(inner_data.first >= 0);
        CHECK(inner_data.first < num_col);
        // Only non-zero (or NaN) entries are kept for bin construction.
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }

  OMP_INIT_EX();
  // private() gives every thread its own default-constructed buffer, so the
  // callback never writes into a vector shared with another thread.
  std::vector<std::pair<int, double>> thread_buffer;
  #pragma omp parallel for schedule(static) private(thread_buffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    get_row_fun(i, thread_buffer);
    ret->PushOneRow(tid, i, thread_buffer);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
781
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
782
783
784
785
786
787
788
789
790
791
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
792
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
793
794
  auto param = Config::Str2Map(parameters);
  Config config;
795
796
797
798
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
799
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
800
801
802
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
803
804
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
805
    auto sample_indices = rand.Sample(nrow, sample_cnt);
806
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
807
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
808
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
809
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
810
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
811
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
812
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
813
814
815
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
816
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
817
818
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
819
820
        }
      }
821
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
822
    }
823
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
824
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
825
826
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
827
828
829
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
830
  } else {
831
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
832
    ret->CreateValid(
833
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
834
  }
835
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
836
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
837
  for (int i = 0; i < ncol_ptr - 1; ++i) {
838
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
839
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
840
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
841
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
842
843
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
844
845
846
847
848
849
850
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
851
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
852
    }
853
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
854
  }
855
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
856
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
857
  *out = ret.release();
858
  API_END();
Guolin Ke's avatar
Guolin Ke committed
859
860
}

Guolin Ke's avatar
Guolin Ke committed
861
int LGBM_DatasetGetSubset(
862
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
863
864
865
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
866
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
867
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
868
869
  auto param = Config::Str2Map(parameters);
  Config config;
870
871
872
873
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
874
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
875
  CHECK(num_used_row_indices > 0);
876
877
878
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
879
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
880
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
881
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
882
883
884
885
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
886
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
887
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
888
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
889
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
890
891
892
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
893
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
894
895
896
897
898
899
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
900
int LGBM_DatasetGetFeatureNames(
901
902
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
903
  int* num_feature_names) {
904
905
906
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
907
908
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
909
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
910
911
912
913
  }
  API_END();
}

914
#ifdef _MSC_VER
// C4702 is MSVC's "unreachable code" warning. `#pragma warning` is an MSVC
// extension; GCC reports it as an unknown pragma (-Wunknown-pragmas), so only
// emit it for compilers that define _MSC_VER.
#pragma warning(disable : 4702)
#endif
Guolin Ke's avatar
Guolin Ke committed
915
// Deletes the Dataset object behind `handle`.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
921
int LGBM_DatasetSaveBinary(DatasetHandle handle,
922
                           const char* filename) {
923
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
924
925
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
926
  API_END();
927
928
}

929
930
931
932
933
934
935
936
// Writes a text dump of the Dataset to `filename` via Dataset::DumpTextFile.
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
937
// Sets the named field of the Dataset from a caller-supplied array.
// `type` selects how `field_data` is interpreted (float32, int32 or float64).
// Throws if the type is not one of those or the setter rejects the field.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* values = reinterpret_cast<const float*>(field_data);
    stored = dataset->SetFloatField(field_name, values, count);
  } else if (type == C_API_DTYPE_INT32) {
    const int* values = reinterpret_cast<const int*>(field_data);
    stored = dataset->SetIntField(field_name, values, count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    const double* values = reinterpret_cast<const double*>(field_data);
    stored = dataset->SetDoubleField(field_name, values, count);
  }
  if (!stored) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
956
// Looks up a named field and returns a pointer to its data, its length, and
// its element type. The typed getters are tried in order (float32, int32,
// float64, int8); the first that succeeds determines `*out_type`.
// A null data pointer is reported with length 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (dataset->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

982
983
984
985
986
987
988
// Forwards a new parameter string to Dataset::ResetConfig.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
989
int LGBM_DatasetGetNumData(DatasetHandle handle,
990
                           int* out) {
991
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
992
993
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
994
  API_END();
995
996
}

Guolin Ke's avatar
Guolin Ke committed
997
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
998
                              int* out) {
999
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1000
1001
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1002
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1003
}
1004

1005
1006
1007
1008
1009
1010
1011
1012
1013
// Appends the features of `source` to `target` via Dataset::addFeaturesFrom.
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* dst = reinterpret_cast<Dataset*>(target);
  Dataset* src = reinterpret_cast<Dataset*>(source);
  dst->addFeaturesFrom(src);
  API_END();
}

1014
1015
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1016
int LGBM_BoosterCreate(const DatasetHandle train_data,
1017
1018
                       const char* parameters,
                       BoosterHandle* out) {
1019
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1020
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1021
1022
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1023
  API_END();
1024
1025
}

Guolin Ke's avatar
Guolin Ke committed
1026
int LGBM_BoosterCreateFromModelfile(
1027
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1028
  int* out_num_iterations,
1029
  BoosterHandle* out) {
1030
  API_BEGIN();
wxchan's avatar
wxchan committed
1031
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1032
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1033
  *out = ret.release();
1034
  API_END();
1035
1036
}

Guolin Ke's avatar
Guolin Ke committed
1037
int LGBM_BoosterLoadModelFromString(
1038
1039
1040
1041
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1042
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1043
1044
1045
1046
1047
1048
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1049
#ifdef _MSC_VER
// C4702 is MSVC's "unreachable code" warning. `#pragma warning` is an MSVC
// extension; GCC reports it as an unknown pragma (-Wunknown-pragmas), so only
// emit it for compilers that define _MSC_VER.
#pragma warning(disable : 4702)
#endif
Guolin Ke's avatar
Guolin Ke committed
1050
// Deletes the Booster object behind `handle`.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1056
// Forwards the [start_iter, end_iter] arguments to Booster::ShuffleModels.
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1063
// Merges `other_handle`'s Booster into `handle`'s via Booster::MergeFrom.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* into = reinterpret_cast<Booster*>(handle);
  Booster* from = reinterpret_cast<Booster*>(other_handle);
  into->MergeFrom(from);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1072
int LGBM_BoosterAddValidData(BoosterHandle handle,
1073
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1074
1075
1076
1077
1078
1079
1080
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1081
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1082
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1083
1084
1085
1086
1087
1088
1089
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1090
// Forwards a new parameter string to Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1097
// Reports the underlying boosting model's NumberOfClasses through `out_len`.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1104
1105
1106
1107
1108
1109
1110
// Forwards an nrow x ncol leaf-prediction matrix to Booster::Refit.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1111
// Runs one training iteration; `*is_finished` is set to 1 when
// Booster::TrainOneIter returns true, otherwise 0.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1122
// Runs one training iteration using caller-supplied gradients and hessians.
// `*is_finished` is set to 1 when Booster::TrainOneIter returns true, else 0.
// Not supported when the library is built with SCORE_T_USE_DOUBLE: that build
// calls Log::Fatal instead.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  #ifdef SCORE_T_USE_DOUBLE
  // None of the inputs are consumed in this configuration; mark them used so
  // the build does not emit unused-parameter warnings.
  (void)handle;
  (void)grad;
  (void)hess;
  (void)is_finished;
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  if (ref_booster->TrainOneIter(grad, hess)) {
    *is_finished = 1;
  } else {
    *is_finished = 0;
  }
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1140
// Undoes the most recent training iteration via Booster::RollbackOneIter.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1147
// Reports the boosting model's current iteration through `out_iteration`.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1153

1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
// Reports how many models the boosting model builds per iteration.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

// Reports the total number of models held by the boosting model.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1168
// Reports Booster::GetEvalCounts through `out_len`.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1175
// Fills `out_strs` via Booster::GetEvalNames and reports the count it returns.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetEvalNames(out_strs);
  *out_len = written;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1182
// Fills `out_strs` via Booster::GetFeatureNames and reports the count it returns.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const int written = booster->GetFeatureNames(out_strs);
  *out_len = written;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1189
// Reports the feature count as MaxFeatureIdx() + 1 (indices are zero-based).
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const int max_idx = boosting->MaxFeatureIdx();
  *out_len = max_idx + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1196
// Copies the evaluation results for dataset `data_idx` into `out_results`
// and reports how many were written. `out_results` must have room for all
// of them; no bound is applied to the copy.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto scores = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(scores.size());
  double* dst = out_results;
  for (const auto& score : scores) {
    *dst++ = static_cast<double>(score);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1211
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1212
1213
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1214
1215
1216
1217
1218
1219
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1220
int LGBM_BoosterGetPredict(BoosterHandle handle,
1221
1222
1223
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1224
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1225
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1226
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1227
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1228
1229
}

Guolin Ke's avatar
Guolin Ke committed
1230
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1231
1232
1233
1234
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1235
                               const char* parameter,
1236
                               const char* result_filename) {
1237
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1238
1239
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1240
1241
1242
1243
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1244
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1245
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1246
                       config, result_filename);
1247
  API_END();
1248
1249
}

Guolin Ke's avatar
Guolin Ke committed
1250
// Computes the output buffer length needed to predict `num_row` rows:
// rows times the per-row output size for the requested prediction type.
// The multiplication is done in 64 bits.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool want_leaf = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool want_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  const auto per_row = booster->GetBoosting()->NumPredictOneRow(num_iteration, want_leaf, want_contrib);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1262
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1273
                              const char* parameter,
1274
1275
                              int64_t* out_len,
                              double* out_result) {
1276
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1277
1278
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1279
1280
1281
1282
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1283
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1284
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1285
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1286
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1287
                       config, out_result, out_len);
1288
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1289
}
1290

1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
// Single-row prediction for CSR input.
//
// Takes the same arguments as LGBM_BoosterPredictForCSR, but hands the row
// accessor to Booster::PredictSingleRow instead of the batch Predict. The
// unnamed int64_t argument is ignored. A positive num_threads found in
// `parameter` (parsed by Config::Str2Map) is forwarded to omp_set_num_threads.
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
                              const char* parameter,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  Config config;
  {
    auto param_map = Config::Str2Map(parameter);
    config.Set(param_map);
  }
  const auto requested_threads = config.num_threads;
  if (requested_threads > 0) {
    omp_set_num_threads(requested_threads);
  }
  auto row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  booster->PredictSingleRow(num_iteration, predict_type, row_fun, config, out_result, out_len);
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1320
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1331
                              const char* parameter,
1332
1333
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1334
1335
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1336
1337
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1348
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1349
1350
1351
1352
1353
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1354
1355
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1356
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1357
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1358
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1359
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1360
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1361
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1362
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1363
1364
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1365
1366
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1367
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1368
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1369
1370
1371
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1372
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1373
1374
1375
1376
1377
1378
1379
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1380
                              const char* parameter,
1381
1382
                              int64_t* out_len,
                              double* out_result) {
1383
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1384
1385
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1386
1387
1388
1389
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1390
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1391
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1392
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1393
                       config, out_result, out_len);
1394
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1395
}
1396

1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
// Single-row prediction for dense input.
//
// `data` holds one row of ncol values of type data_type, viewed as a 1 x ncol
// matrix (is_row_major is forwarded unchanged). The row accessor is passed to
// Booster::PredictSingleRow. A positive num_threads found in `parameter`
// (parsed by Config::Str2Map) is forwarded to omp_set_num_threads.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                              const void* data,
                              int data_type,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
                              const char* parameter,
                              int64_t* out_len,
                              double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  Config config;
  {
    auto param_map = Config::Str2Map(parameter);
    config.Set(param_map);
  }
  const auto requested_threads = config.num_threads;
  if (requested_threads > 0) {
    omp_set_num_threads(requested_threads);
  }
  const int single_row = 1;
  auto row_fun = RowPairFunctionFromDenseMatric(data, single_row, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, row_fun, config, out_result, out_len);
  API_END();
}


1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
// Batch prediction where each row is a separate dense buffer.
//
// data[r] points to ncol values (of type data_type) for row r, for nrow rows.
// A positive num_threads found in `parameter` (parsed by Config::Str2Map) is
// forwarded to omp_set_num_threads. Booster::Predict writes the predictions
// to out_result and their count to out_len.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  Config config;
  {
    auto param_map = Config::Str2Map(parameter);
    config.Set(param_map);
  }
  const auto requested_threads = config.num_threads;
  if (requested_threads > 0) {
    omp_set_num_threads(requested_threads);
  }
  auto row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, row_fun, config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1446
int LGBM_BoosterSaveModel(BoosterHandle handle,
1447
                          int start_iteration,
1448
1449
                          int num_iteration,
                          const char* filename) {
1450
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1451
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1452
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1453
1454
1455
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1456
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1457
                                  int start_iteration,
1458
                                  int num_iteration,
1459
                                  int64_t buffer_len,
1460
                                  int64_t* out_len,
1461
                                  char* out_str) {
1462
1463
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1464
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1465
  *out_len = static_cast<int64_t>(model.size()) + 1;
1466
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1467
    std::memcpy(out_str, model.c_str(), *out_len);
1468
1469
1470
1471
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1472
int LGBM_BoosterDumpModel(BoosterHandle handle,
1473
                          int start_iteration,
1474
                          int num_iteration,
1475
1476
                          int64_t buffer_len,
                          int64_t* out_len,
1477
                          char* out_str) {
wxchan's avatar
wxchan committed
1478
1479
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1480
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1481
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1482
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1483
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1484
  }
1485
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1486
}
1487

Guolin Ke's avatar
Guolin Ke committed
1488
// Stores into *out_val the value of leaf leaf_idx of tree tree_idx, as
// returned by Booster::GetLeafValue (converted to double).
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  *out_val = static_cast<double>(ref_booster->GetLeafValue(tree_idx, leaf_idx));
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1498
// Overwrites the value of leaf leaf_idx of tree tree_idx with `val`,
// delegating to Booster::SetLeafValue.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  ref_booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
// Writes the feature importances computed by Booster::FeatureImportance
// (for num_iteration and importance_type) into out_results.
// out_results must have room for one double per returned entry; this
// function performs no size check.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double value : importances) {
    *dst++ = value;
  }
  API_END();
}

1521
1522
1523
1524
1525
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1526
  Config config;
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1543
1544
1545
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1546
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1547
  if (num_machines > 1) {
1548
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1549
1550
1551
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1552

Guolin Ke's avatar
Guolin Ke committed
1553
// ---- start of some help functions
1554
1555
1556

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1557
  if (data_type == C_API_DTYPE_FLOAT32) {
1558
1559
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1560
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1561
        std::vector<double> ret(num_col);
1562
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1563
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1564
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1565
1566
1567
1568
        }
        return ret;
      };
    } else {
1569
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1570
        std::vector<double> ret(num_col);
1571
        for (int i = 0; i < num_col; ++i) {
1572
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1573
1574
1575
1576
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1577
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1578
1579
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1580
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1581
        std::vector<double> ret(num_col);
1582
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1583
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1584
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1585
1586
1587
1588
        }
        return ret;
      };
    } else {
1589
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1590
        std::vector<double> ret(num_col);
1591
        for (int i = 0; i < num_col; ++i) {
1592
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1593
1594
1595
1596
1597
        }
        return ret;
      };
    }
  }
1598
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1599
1600
1601
1602
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1603
1604
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1605
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1606
1607
1608
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1609
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1610
          ret.emplace_back(i, raw_values[i]);
1611
        }
Guolin Ke's avatar
Guolin Ke committed
1612
1613
1614
      }
      return ret;
    };
1615
  }
Guolin Ke's avatar
Guolin Ke committed
1616
  return nullptr;
1617
1618
}

1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
// data is array of pointers to individual rows
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto inner_function = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto raw_values = inner_function(0);
    std::vector<std::pair<int, double>> ret;
    for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
      if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
        ret.emplace_back(i, raw_values[i]);
      }
    }
    return ret;
  };
}

1635
std::function<std::vector<std::pair<int, double>>(int idx)>
1636
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1637
  if (data_type == C_API_DTYPE_FLOAT32) {
1638
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1639
    if (indptr_type == C_API_DTYPE_INT32) {
1640
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1641
      return [=] (int idx) {
1642
1643
1644
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1645
1646
1647
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1648
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1649
          ret.emplace_back(indices[i], data_ptr[i]);
1650
1651
1652
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1653
    } else if (indptr_type == C_API_DTYPE_INT64) {
1654
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1655
      return [=] (int idx) {
1656
1657
1658
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1659
1660
1661
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1662
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1663
          ret.emplace_back(indices[i], data_ptr[i]);
1664
1665
1666
1667
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1668
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1669
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1670
    if (indptr_type == C_API_DTYPE_INT32) {
1671
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1672
      return [=] (int idx) {
1673
1674
1675
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1676
1677
1678
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1679
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1680
          ret.emplace_back(indices[i], data_ptr[i]);
1681
1682
1683
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1684
    } else if (indptr_type == C_API_DTYPE_INT64) {
1685
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1686
      return [=] (int idx) {
1687
1688
1689
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1690
1691
1692
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1693
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1694
          ret.emplace_back(indices[i], data_ptr[i]);
1695
1696
1697
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1698
1699
    }
  }
1700
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1701
1702
}

Guolin Ke's avatar
Guolin Ke committed
1703
std::function<std::pair<int, double>(int idx)>
1704
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1705
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1706
  if (data_type == C_API_DTYPE_FLOAT32) {
1707
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1708
    if (col_ptr_type == C_API_DTYPE_INT32) {
1709
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1710
1711
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1712
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1713
1714
1715
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1716
        }
Guolin Ke's avatar
Guolin Ke committed
1717
1718
1719
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1720
      };
Guolin Ke's avatar
Guolin Ke committed
1721
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1722
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1723
1724
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1725
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1726
1727
1728
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1729
        }
Guolin Ke's avatar
Guolin Ke committed
1730
1731
1732
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1733
      };
Guolin Ke's avatar
Guolin Ke committed
1734
    }
Guolin Ke's avatar
Guolin Ke committed
1735
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1736
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1737
    if (col_ptr_type == C_API_DTYPE_INT32) {
1738
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1739
1740
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1741
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1742
1743
1744
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1745
        }
Guolin Ke's avatar
Guolin Ke committed
1746
1747
1748
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1749
      };
Guolin Ke's avatar
Guolin Ke committed
1750
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1751
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1752
1753
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1754
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1755
1756
1757
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1758
        }
Guolin Ke's avatar
Guolin Ke committed
1759
1760
1761
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1762
      };
Guolin Ke's avatar
Guolin Ke committed
1763
1764
    }
  }
1765
  throw std::runtime_error("Unknown data type in CSC matrix");
1766
1767
}

Guolin Ke's avatar
Guolin Ke committed
1768
// Creates an iterator over column col_idx of a CSC matrix. The per-entry
// accessor comes from IterateFunctionFromCSC, which validates col_idx and
// the data/pointer types.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

// Returns the value stored at row idx of this column, or 0.0 if no entry is
// stored there.
//
// The cursor only moves forward. It advances through stored entries while
// they lie before row idx, and stops at the end of the column. Asking for a
// row before the current cursor position therefore returns 0.0, even if that
// row holds a value.
double CSC_RowIterator::Get(int idx) {
  while (idx > cur_idx_ && !is_end_) {
    auto ret = iter_fun_(nonzero_idx_);
    if (ret.first < 0) {
      // A negative row index from the accessor marks the end of the column.
      is_end_ = true;
      break;
    }
    cur_idx_ = ret.first;
    cur_val_ = ret.second;
    ++nonzero_idx_;
  }
  if (idx == cur_idx_) {
    return cur_val_;
  } else {
    return 0.0;
  }
}

// Returns the next stored (row index, value) pair of this column and advances.
//
// Once the accessor reports a negative row index, the iterator is marked as
// ended. That result is returned, and later calls return (-1, 0.0).
// NOTE(review): this does not update cur_idx_/cur_val_, so interleaving it
// with Get() on the same iterator is presumably unsupported — confirm.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (!is_end_) {
    auto ret = iter_fun_(nonzero_idx_);
    ++nonzero_idx_;
    if (ret.first < 0) {
      is_end_ = true;
    }
    return ret;
  } else {
    return std::make_pair(-1, 0.0);
  }
}