c_api.cpp 63.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
2

Guolin Ke's avatar
Guolin Ke committed
3
4
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
5
6
7
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
8
#include <LightGBM/network.h>
9
10
11
12
13
14
15
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
16
17

#include <string>
18
19
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
20
#include <memory>
wxchan's avatar
wxchan committed
21
#include <mutex>
22
23
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
24

25
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
26

Guolin Ke's avatar
Guolin Ke committed
27
28
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/*!
 * \brief Record a caught std::exception as the last C API error.
 * \param ex Exception whose what() text is stored via LGBM_SetLastError
 * \return Always -1, the C API failure code
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* what_msg = ex.what();
  LGBM_SetLastError(what_msg);
  return -1;
}
/*!
 * \brief Record a thrown std::string as the last C API error.
 * \param ex Error text stored via LGBM_SetLastError
 * \return Always -1, the C API failure code
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* text = ex.c_str();
  LGBM_SetLastError(text);
  return -1;
}

// API_BEGIN()/API_END() bracket the body of each C API entry point.
// Any exception escaping the body is recorded through LGBM_APIHandleException
// (which calls LGBM_SetLastError) and turned into a -1 return code; a body
// that completes normally falls through to `return 0`.
// Exceptions are caught by const reference: handlers only read them.
#define API_BEGIN() try {
#define API_END() } \
catch(const std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(const std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

Guolin Ke's avatar
Guolin Ke committed
45
class Booster {
Nikita Titov's avatar
Nikita Titov committed
46
 public:
Guolin Ke's avatar
Guolin Ke committed
47
  explicit Booster(const char* filename) {
48
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
49
50
  }

Guolin Ke's avatar
Guolin Ke committed
51
  Booster(const Dataset* train_data,
52
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
53
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
54
    config_.Set(param);
55
56
57
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
58
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
59
    if (config_.input_model.size() > 0) {
60
61
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
62
    }
Guolin Ke's avatar
Guolin Ke committed
63

Guolin Ke's avatar
Guolin Ke committed
64
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
65

66
67
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
68
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
69
    if (config_.tree_learner == std::string("feature")) {
70
      Log::Fatal("Do not support feature parallel in c api");
71
    }
Guolin Ke's avatar
Guolin Ke committed
72
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
73
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
74
      config_.tree_learner = "serial";
75
    }
Guolin Ke's avatar
Guolin Ke committed
76
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
77
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
78
79
80
81
82
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
83
84
85
86
  }

  ~Booster() {
  }
87

88
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
89
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
90
91
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
92
93
94
95
96
97
98
99
100
101
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
102
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
103
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
104
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
110
111
112
113
114
115
116
117
118
119
120
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
121
122
123
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
124
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
125
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
126
    if (param.count("num_class")) {
127
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
130
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
131
    }
Guolin Ke's avatar
Guolin Ke committed
132
    if (param.count("metric")) {
133
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
134
    }
Guolin Ke's avatar
Guolin Ke committed
135
136

    config_.Set(param);
137
138
139
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
140
141
142

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
143
144
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
152
153
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
154
    }
Guolin Ke's avatar
Guolin Ke committed
155

Guolin Ke's avatar
Guolin Ke committed
156
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
157
158
159
160
161
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
162
163
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
164
165
166
167
168
169
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
170
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
171
  }
Guolin Ke's avatar
Guolin Ke committed
172

173
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
174
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
175
    return boosting_->TrainOneIter(nullptr, nullptr);
176
177
  }

Guolin Ke's avatar
Guolin Ke committed
178
179
180
181
182
183
184
185
186
187
188
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

189
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
190
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
191
    return boosting_->TrainOneIter(gradients, hessians);
192
193
  }

wxchan's avatar
wxchan committed
194
195
196
197
198
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
  void PredictSingleRow(int num_iteration, int predict_type,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> lock(mutex_);

    if (single_row_predictor_.get() == nullptr) {
      bool is_predict_leaf = false;
      bool is_raw_score = false;
      bool predict_contrib = false;
      if (predict_type == C_API_PREDICT_LEAF_INDEX) {
        is_predict_leaf = true;
      } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
        is_raw_score = true;
      } else if (predict_type == C_API_PREDICT_CONTRIB) {
        predict_contrib = true;
      } else {
        is_raw_score = false;
      }

      // TODO: config could be optimized away... (maybe using lambda callback?)
      single_row_predictor_.reset(new Predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin));
      single_row_num_pred_in_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
      single_row_predict_function_ = single_row_predictor_->GetPredictFunction();
    }

    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
    single_row_predict_function_(one_row, pred_wrt_ptr);

    *out_len = single_row_num_pred_in_one_row_;
  }


Guolin Ke's avatar
Guolin Ke committed
234
235
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
236
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
237
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
238
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
239
240
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
241
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
242
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
243
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
244
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
245
      is_raw_score = true;
246
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
247
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
248
249
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
250
    }
Guolin Ke's avatar
Guolin Ke committed
251

Guolin Ke's avatar
Guolin Ke committed
252
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
253
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
254
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
255
    auto pred_fun = predictor.GetPredictFunction();
256
257
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
258
    for (int i = 0; i < nrow; ++i) {
259
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
260
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
261
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
262
      pred_fun(one_row, pred_wrt_ptr);
263
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
264
    }
265
    OMP_THROW_EX();
266
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
267
268
269
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
270
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
271
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
272
273
274
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
275
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
276
277
278
279
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
280
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
281
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
282
283
284
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
285
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
286
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
287
288
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
289
290
  }

Guolin Ke's avatar
Guolin Ke committed
291
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
292
293
294
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

295
296
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
297
  }
298

299
  void LoadModelFromString(const char* model_str) {
300
301
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
302
303
  }

304
305
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
306
307
  }

308
  std::string DumpModel(int start_iteration, int num_iteration) {
309
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
310
  }
311

312
313
314
315
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
316
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
317
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
318
319
320
321
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
322
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
323
324
  }

325
  void ShuffleModels(int start_iter, int end_iter) {
326
    std::lock_guard<std::mutex> lock(mutex_);
327
    boosting_->ShuffleModels(start_iter, end_iter);
328
329
  }

wxchan's avatar
wxchan committed
330
331
332
333
334
335
336
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
337

wxchan's avatar
wxchan committed
338
339
340
341
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
342
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
343
344
345
346
347
348
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
349
350
351
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
352
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
353
354
355
356
357
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
358
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
359

Nikita Titov's avatar
Nikita Titov committed
360
 private:
wxchan's avatar
wxchan committed
361
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
362
  std::unique_ptr<Boosting> boosting_;
363
364
365
366
  std::unique_ptr<Predictor> single_row_predictor_;
  PredictFunction single_row_predict_function_;
  int64_t single_row_num_pred_in_one_row_;

Guolin Ke's avatar
Guolin Ke committed
367
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
368
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
369
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
370
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
371
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
372
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
373
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
374
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
375
376
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
377
378
};

379
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
380
381
382

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
383
384
385
386
387
388
389
390
// Helpers that turn raw C API buffers into per-row accessor functions.
// Only the declarations are here; the definitions are outside this view.

/*! \brief Accessor over a dense matrix; returns the values of row `row_idx`. */
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col,
                           int data_type, int is_row_major);

/*! \brief Accessor over a dense matrix; returns row `row_idx` as (column, value) pairs. */
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col,
                               int data_type, int is_row_major);

/*! \brief Accessor over an array of dense row pointers; returns (column, value) pairs. */
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

/*! \brief Accessor over a CSR matrix; returns the entries of row `idx` as (column, value) pairs. */
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
397
398
399

// Row iterator of on column for CSC matrix
/*!
 * \brief Iterates over the rows of one column of a CSC matrix.
 *
 * The member functions are defined out of line (not visible here).
 */
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() = default;

  /*! \brief Value stored at row `idx`; rows must be requested in ascending order. */
  double Get(int idx);

  /*! \brief Next non-zero (row, value) pair; a negative row index means no more data. */
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
419
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
420
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
421
422
}

Guolin Ke's avatar
Guolin Ke committed
423
int LGBM_DatasetCreateFromFile(const char* filename,
424
425
426
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
427
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
428
429
  auto param = Config::Str2Map(parameters);
  Config config;
430
431
432
433
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
434
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
435
  if (reference == nullptr) {
436
437
438
439
440
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
441
  } else {
442
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
443
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
444
  }
445
  API_END();
Guolin Ke's avatar
Guolin Ke committed
446
447
}

448

Guolin Ke's avatar
Guolin Ke committed
449
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
450
451
452
453
454
455
456
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
457
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
458
459
  auto param = Config::Str2Map(parameters);
  Config config;
460
461
462
463
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
464
  DatasetLoader loader(config, nullptr, 1, nullptr);
465
466
467
468
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
469
470
}

471

Guolin Ke's avatar
Guolin Ke committed
472
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
473
474
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
475
476
477
478
479
480
481
482
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
483
int LGBM_DatasetPushRows(DatasetHandle dataset,
484
485
486
487
488
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
489
490
491
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
492
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
493
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
494
  for (int i = 0; i < nrow; ++i) {
495
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
496
497
498
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
499
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
500
  }
501
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
502
503
504
505
506
507
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
508
/*!
 * \brief Push the rows of a CSR matrix (nindptr - 1 rows) into `dataset`, starting at start_row.
 *
 * The column-count argument is unused. When the pushed rows reach the end of
 * the dataset, FinishLoad() is called. Rows that would fall past num_data(),
 * or a negative start_row, are rejected with an error.
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  // start_row/nindptr come straight from the caller and are passed to
  // PushOneRow as start_row + i with nothing else bounding them.
  if (start_row < 0 ||
      start_row + nrow > static_cast<int64_t>(p_dataset->num_data())) {
    Log::Fatal("Pushed rows are out of the range of the dataset");
  }
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid,
                          static_cast<data_size_t>(start_row + i), one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
539
int LGBM_DatasetCreateFromMat(const void* data,
540
541
542
543
544
545
546
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
568
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
569
570
  auto param = Config::Str2Map(parameters);
  Config config;
571
572
573
574
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
575
  std::unique_ptr<Dataset> ret;
576
577
578
579
580
581
582
583
584
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
585

Guolin Ke's avatar
Guolin Ke committed
586
587
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
588
    Random rand(config.data_random_seed);
589
590
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
591
    sample_cnt = static_cast<int>(sample_indices.size());
592
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
593
    std::vector<std::vector<int>> sample_idx(ncol);
594
595
596

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
597
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
598
      auto idx = sample_indices[i];
599
600
601
602
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
603

604
605
606
607
608
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
609
        }
Guolin Ke's avatar
Guolin Ke committed
610
611
      }
    }
Guolin Ke's avatar
Guolin Ke committed
612
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
613
614
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
615
616
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
617
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
618
  } else {
619
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
620
    ret->CreateValid(
621
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
622
  }
623
624
625
626
627
628
629
630
631
632
633
634
635
636
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
637
638
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
639
  *out = ret.release();
640
  API_END();
641
642
}

Guolin Ke's avatar
Guolin Ke committed
643
int LGBM_DatasetCreateFromCSR(const void* indptr,
644
645
646
647
648
649
650
651
652
653
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
654
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
655
656
  auto param = Config::Str2Map(parameters);
  Config config;
657
658
659
660
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
661
  std::unique_ptr<Dataset> ret;
662
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
663
664
665
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
666
667
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
668
    auto sample_indices = rand.Sample(nrow, sample_cnt);
669
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
670
671
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
672
673
674
675
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
676
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
677
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
678
679
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
680
681
682
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
683
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
684
685
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
686
687
688
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
689
  } else {
690
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
691
    ret->CreateValid(
692
      reinterpret_cast<const Dataset*>(reference));
693
  }
694
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
695
  #pragma omp parallel for schedule(static)
696
  for (int i = 0; i < nindptr - 1; ++i) {
697
    OMP_LOOP_EX_BEGIN();
698
699
700
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
701
    OMP_LOOP_EX_END();
702
  }
703
  OMP_THROW_EX();
704
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
705
  *out = ret.release();
706
  API_END();
707
708
}

709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
                              int num_rows,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
  API_BEGIN();

  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
759

760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
            const int tid = omp_get_thread_num();
            get_row_fun(i, threadBuffer);

            ret->PushOneRow(tid, i, threadBuffer);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
779
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
780
781
782
783
784
785
786
787
788
789
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
790
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
791
792
  auto param = Config::Str2Map(parameters);
  Config config;
793
794
795
796
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
797
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
798
799
800
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
801
802
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
803
    auto sample_indices = rand.Sample(nrow, sample_cnt);
804
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
805
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
806
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
807
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
808
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
809
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
810
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
811
812
813
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
814
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
815
816
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
817
818
        }
      }
819
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
820
    }
821
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
822
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
823
824
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
825
826
827
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
828
  } else {
829
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
830
    ret->CreateValid(
831
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
832
  }
833
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
834
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
835
  for (int i = 0; i < ncol_ptr - 1; ++i) {
836
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
837
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
838
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
839
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
840
841
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
842
843
844
845
846
847
848
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
849
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
850
    }
851
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
852
  }
853
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
854
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
855
  *out = ret.release();
856
  API_END();
Guolin Ke's avatar
Guolin Ke committed
857
858
}

Guolin Ke's avatar
Guolin Ke committed
859
int LGBM_DatasetGetSubset(
860
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
861
862
863
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
864
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
865
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
866
867
  auto param = Config::Str2Map(parameters);
  Config config;
868
869
870
871
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
872
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
873
  CHECK(num_used_row_indices > 0);
874
875
876
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
877
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
878
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
879
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
880
881
882
883
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
884
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
885
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
886
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
887
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
888
889
890
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
891
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
892
893
894
895
896
897
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
898
int LGBM_DatasetGetFeatureNames(
899
900
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
901
  int* num_feature_names) {
902
903
904
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
905
906
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
907
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
908
909
910
911
  }
  API_END();
}

912
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
913
// Destroys a Dataset previously handed out through a DatasetHandle.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* dataset = reinterpret_cast<Dataset*>(handle);
  delete dataset;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
919
int LGBM_DatasetSaveBinary(DatasetHandle handle,
920
                           const char* filename) {
921
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
922
923
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
924
  API_END();
925
926
}

927
928
929
930
931
932
933
934
// Writes the dataset to `filename` via Dataset::DumpTextFile.
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
935
// Sets a named field of the dataset from a typed buffer of num_element values.
// Supported element types: float32, int32, float64. An unsupported type and an
// unknown or mismatched field both raise the same error.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  const int32_t len = static_cast<int32_t>(num_element);
  bool is_success;
  switch (type) {
    case C_API_DTYPE_FLOAT32:
      is_success = dataset->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), len);
      break;
    case C_API_DTYPE_INT32:
      is_success = dataset->SetIntField(field_name, reinterpret_cast<const int*>(field_data), len);
      break;
    case C_API_DTYPE_FLOAT64:
      is_success = dataset->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), len);
      break;
    default:
      is_success = false;
      break;
  }
  if (!is_success) { throw std::runtime_error("Input data type error or field not found"); }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
954
// Looks up a named field and returns a pointer to its data, its length and its
// element type. Lookups are tried as float32, int32, float64, then int8; the
// first match wins. A found field with a null data pointer reports length 0.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  if (dataset->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (dataset->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (dataset->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (dataset->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

980
981
982
983
984
985
986
// Forwards a parameter string to Dataset::ResetConfig.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
987
int LGBM_DatasetGetNumData(DatasetHandle handle,
988
                           int* out) {
989
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
990
991
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
992
  API_END();
993
994
}

Guolin Ke's avatar
Guolin Ke committed
995
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
996
                              int* out) {
997
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
998
999
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1000
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1001
}
1002

1003
1004
1005
1006
1007
1008
1009
1010
1011
// Calls target->addFeaturesFrom(source) on the two dataset handles.
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* source_dataset = reinterpret_cast<Dataset*>(source);
  reinterpret_cast<Dataset*>(target)->addFeaturesFrom(source_dataset);
  API_END();
}

1012
1013
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1014
int LGBM_BoosterCreate(const DatasetHandle train_data,
1015
1016
                       const char* parameters,
                       BoosterHandle* out) {
1017
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1018
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1019
1020
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1021
  API_END();
1022
1023
}

Guolin Ke's avatar
Guolin Ke committed
1024
int LGBM_BoosterCreateFromModelfile(
1025
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1026
  int* out_num_iterations,
1027
  BoosterHandle* out) {
1028
  API_BEGIN();
wxchan's avatar
wxchan committed
1029
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1030
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1031
  *out = ret.release();
1032
  API_END();
1033
1034
}

Guolin Ke's avatar
Guolin Ke committed
1035
int LGBM_BoosterLoadModelFromString(
1036
1037
1038
1039
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1040
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1041
1042
1043
1044
1045
1046
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1047
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1048
// Destroys a Booster previously handed out through a BoosterHandle.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  delete booster;
  API_END();
}

1054
// Forwards [start_iter, end_iter] to Booster::ShuffleModels.
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1061
// Merges the booster behind `other_handle` into the one behind `handle`
// via Booster::MergeFrom.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* other = reinterpret_cast<Booster*>(other_handle);
  reinterpret_cast<Booster*>(handle)->MergeFrom(other);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1070
int LGBM_BoosterAddValidData(BoosterHandle handle,
1071
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1072
1073
1074
1075
1076
1077
1078
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1079
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1080
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1081
1082
1083
1084
1085
1086
1087
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1088
// Forwards a parameter string to Booster::ResetConfig.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1095
// Reports Boosting::NumberOfClasses through *out_len.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1102
1103
1104
1105
1106
1107
1108
// Forwards an nrow x ncol leaf-prediction matrix to Booster::Refit.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1109
// Runs one training iteration; *is_finished becomes 1 when
// Booster::TrainOneIter returns true, otherwise 0.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1120
// Runs one training iteration using caller-provided gradients and hessians.
// *is_finished becomes 1 when Booster::TrainOneIter returns true, otherwise 0.
// When SCORE_T_USE_DOUBLE is defined this always fails with a fatal error
// (presumably because score_t is then double and float buffers cannot be used).
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  // Declared only on this path so the SCORE_T_USE_DOUBLE build has no unused variable.
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  if (ref_booster->TrainOneIter(grad, hess)) {
    *is_finished = 1;
  } else {
    *is_finished = 0;
  }
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1138
// Undoes the most recent training iteration via Booster::RollbackOneIter.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1145
// Reports Boosting::GetCurrentIteration through *out_iteration.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1151

1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
// Reports Boosting::NumModelPerIteration through *out_tree_per_iteration.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

// Reports Boosting::NumberOfTotalModel through *out_models.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1166
// Reports Booster::GetEvalCounts through *out_len.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1173
// Fills out_strs via Booster::GetEvalNames and reports the count it returns.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1180
// Fills out_strs via Booster::GetFeatureNames and reports the count it returns.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1187
// Reports the number of features as Boosting::MaxFeatureIdx() + 1.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1194
// Copies the evaluation results for dataset `data_idx` into out_results and
// reports how many were written. out_results must have room for all of them.
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto results = boosting->GetEvalAt(data_idx);
  *out_len = static_cast<int>(results.size());
  double* dst = out_results;
  for (const auto& metric_value : results) {
    *dst++ = static_cast<double>(metric_value);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1209
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1210
1211
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1212
1213
1214
1215
1216
1217
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1218
int LGBM_BoosterGetPredict(BoosterHandle handle,
1219
1220
1221
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1222
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1223
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1224
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1225
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1226
1227
}

Guolin Ke's avatar
Guolin Ke committed
1228
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1229
1230
1231
1232
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1233
                               const char* parameter,
1234
                               const char* result_filename) {
1235
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1236
1237
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1238
1239
1240
1241
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1242
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1243
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1244
                       config, result_filename);
1245
  API_END();
1246
1247
}

Guolin Ke's avatar
Guolin Ke committed
1248
// Computes the number of prediction outputs for num_row rows: num_row times
// Boosting::NumPredictOneRow, using 64-bit arithmetic for the product.
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  const bool want_leaf_index = predict_type == C_API_PREDICT_LEAF_INDEX;
  const bool want_contrib = predict_type == C_API_PREDICT_CONTRIB;
  const auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const auto per_row = boosting->NumPredictOneRow(num_iteration, want_leaf_index, want_contrib);
  *out_len = static_cast<int64_t>(num_row) * per_row;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1260
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1271
                              const char* parameter,
1272
1273
                              int64_t* out_len,
                              double* out_result) {
1274
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1275
1276
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1277
1278
1279
1280
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1281
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1282
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1283
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1284
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1285
                       config, out_result, out_len);
1286
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1287
}
1288

1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
// Single-row variant of LGBM_BoosterPredictForCSR. The CSR row accessor is
// handed to Booster::PredictSingleRow instead of Booster::Predict.
// A positive num_threads in `parameter` is applied to OpenMP before prediction.
// The unnamed int64_t argument (column count) is not used here.
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
                                       int64_t,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto params = Config::Str2Map(parameter);
  Config config;
  config.Set(params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  booster->PredictSingleRow(num_iteration, predict_type, get_row_fun,
                            config, out_result, out_len);
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1318
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1329
                              const char* parameter,
1330
1331
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1332
1333
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1334
1335
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1346
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1347
1348
1349
1350
1351
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1352
1353
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1354
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1355
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1356
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1357
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1358
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1359
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1360
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1361
1362
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1363
1364
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1365
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1366
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1367
1368
1369
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1370
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1371
1372
1373
1374
1375
1376
1377
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1378
                              const char* parameter,
1379
1380
                              int64_t* out_len,
                              double* out_result) {
1381
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1382
1383
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1384
1385
1386
1387
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1388
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1389
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1390
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1391
                       config, out_result, out_len);
1392
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1393
}
1394

1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
// Single-row dense prediction. `data` is treated as a 1 x ncol matrix and
// passed to Booster::PredictSingleRow. A positive num_threads in `parameter`
// is applied to OpenMP.
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto params = Config::Str2Map(parameter);
  Config config;
  config.Set(params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  const int kSingleRow = 1;
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, kSingleRow, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, get_row_fun,
                            config, out_result, out_len);
  API_END();
}


1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
// Predicts for nrow dense rows given as an array of row pointers (`data[i]`
// points to the ncol values of row i). Rows are converted to sparse pairs by
// RowPairFunctionFromDenseRows. A positive num_threads in `parameter` is
// applied to OpenMP.
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto params = Config::Str2Map(parameter);
  Config config;
  config.Set(params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
                   config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1444
int LGBM_BoosterSaveModel(BoosterHandle handle,
1445
                          int start_iteration,
1446
1447
                          int num_iteration,
                          const char* filename) {
1448
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1449
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1450
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1451
1452
1453
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1454
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1455
                                  int start_iteration,
1456
                                  int num_iteration,
1457
                                  int64_t buffer_len,
1458
                                  int64_t* out_len,
1459
                                  char* out_str) {
1460
1461
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1462
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1463
  *out_len = static_cast<int64_t>(model.size()) + 1;
1464
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1465
    std::memcpy(out_str, model.c_str(), *out_len);
1466
1467
1468
1469
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1470
int LGBM_BoosterDumpModel(BoosterHandle handle,
1471
                          int start_iteration,
1472
                          int num_iteration,
1473
1474
                          int64_t buffer_len,
                          int64_t* out_len,
1475
                          char* out_str) {
wxchan's avatar
wxchan committed
1476
1477
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1478
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1479
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1480
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1481
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1482
  }
1483
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1484
}
1485

Guolin Ke's avatar
Guolin Ke committed
1486
// Reads the output value of leaf `leaf_idx` in tree `tree_idx` into *out_val,
// converted to double.
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1496
// Overwrites the output value of leaf `leaf_idx` in tree `tree_idx` with `val`.
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
// Copies the booster's per-feature importances into out_results.
// The caller must supply room for as many doubles as
// Booster::FeatureImportance returns; no size check is done here.
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* cursor = out_results;
  for (const double importance : importances) {
    *cursor++ = importance;
  }
  API_END();
}

1519
1520
1521
1522
1523
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1524
  Config config;
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

// Tears down network state through Network::Dispose. Any exception is
// converted to a -1 return by API_END.
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1541
1542
1543
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1544
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1545
  if (num_machines > 1) {
1546
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1547
1548
1549
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1550

Guolin Ke's avatar
Guolin Ke committed
1551
// ---- start of some helper functions
1552
1553
1554

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1555
  if (data_type == C_API_DTYPE_FLOAT32) {
1556
1557
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1558
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1559
        std::vector<double> ret(num_col);
1560
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1561
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1562
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1563
1564
1565
1566
        }
        return ret;
      };
    } else {
1567
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1568
        std::vector<double> ret(num_col);
1569
        for (int i = 0; i < num_col; ++i) {
1570
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1571
1572
1573
1574
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1575
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1576
1577
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1578
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1579
        std::vector<double> ret(num_col);
1580
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1581
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1582
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1583
1584
1585
1586
        }
        return ret;
      };
    } else {
1587
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1588
        std::vector<double> ret(num_col);
1589
        for (int i = 0; i < num_col; ++i) {
1590
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1591
1592
1593
1594
1595
        }
        return ret;
      };
    }
  }
1596
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1597
1598
1599
1600
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1601
1602
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1603
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1604
1605
1606
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1607
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1608
          ret.emplace_back(i, raw_values[i]);
1609
        }
Guolin Ke's avatar
Guolin Ke committed
1610
1611
1612
      }
      return ret;
    };
1613
  }
Guolin Ke's avatar
Guolin Ke committed
1614
  return nullptr;
1615
1616
}

1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
// data is array of pointers to individual rows
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto inner_function = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto raw_values = inner_function(0);
    std::vector<std::pair<int, double>> ret;
    for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
      if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
        ret.emplace_back(i, raw_values[i]);
      }
    }
    return ret;
  };
}

1633
std::function<std::vector<std::pair<int, double>>(int idx)>
1634
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1635
  if (data_type == C_API_DTYPE_FLOAT32) {
1636
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1637
    if (indptr_type == C_API_DTYPE_INT32) {
1638
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1639
      return [=] (int idx) {
1640
1641
1642
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1643
1644
1645
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1646
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1647
          ret.emplace_back(indices[i], data_ptr[i]);
1648
1649
1650
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1651
    } else if (indptr_type == C_API_DTYPE_INT64) {
1652
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1653
      return [=] (int idx) {
1654
1655
1656
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1657
1658
1659
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1660
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1661
          ret.emplace_back(indices[i], data_ptr[i]);
1662
1663
1664
1665
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1666
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1667
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1668
    if (indptr_type == C_API_DTYPE_INT32) {
1669
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1670
      return [=] (int idx) {
1671
1672
1673
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1674
1675
1676
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1677
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1678
          ret.emplace_back(indices[i], data_ptr[i]);
1679
1680
1681
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1682
    } else if (indptr_type == C_API_DTYPE_INT64) {
1683
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1684
      return [=] (int idx) {
1685
1686
1687
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1688
1689
1690
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1691
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1692
          ret.emplace_back(indices[i], data_ptr[i]);
1693
1694
1695
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1696
1697
    }
  }
1698
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1699
1700
}

Guolin Ke's avatar
Guolin Ke committed
1701
std::function<std::pair<int, double>(int idx)>
1702
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1703
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1704
  if (data_type == C_API_DTYPE_FLOAT32) {
1705
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1706
    if (col_ptr_type == C_API_DTYPE_INT32) {
1707
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1708
1709
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1710
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1711
1712
1713
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1714
        }
Guolin Ke's avatar
Guolin Ke committed
1715
1716
1717
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1718
      };
Guolin Ke's avatar
Guolin Ke committed
1719
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1720
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1721
1722
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1723
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1724
1725
1726
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1727
        }
Guolin Ke's avatar
Guolin Ke committed
1728
1729
1730
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1731
      };
Guolin Ke's avatar
Guolin Ke committed
1732
    }
Guolin Ke's avatar
Guolin Ke committed
1733
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1734
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1735
    if (col_ptr_type == C_API_DTYPE_INT32) {
1736
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1737
1738
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1739
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1740
1741
1742
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1743
        }
Guolin Ke's avatar
Guolin Ke committed
1744
1745
1746
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1747
      };
Guolin Ke's avatar
Guolin Ke committed
1748
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1749
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1750
1751
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1752
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1753
1754
1755
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1756
        }
Guolin Ke's avatar
Guolin Ke committed
1757
1758
1759
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1760
      };
Guolin Ke's avatar
Guolin Ke committed
1761
1762
    }
  }
1763
  throw std::runtime_error("Unknown data type in CSC matrix");
1764
1765
}

Guolin Ke's avatar
Guolin Ke committed
1766
// Binds this iterator to CSC column `col_idx`; the column's nonzeros are
// fetched through the function built by IterateFunctionFromCSC.
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                       ncol_ptr, nelem, col_idx)) {
}

// Returns the value stored at row `idx` of this column, or 0 if that row has
// no stored entry.
// The cursor only moves forward: it advances through nonzeros while the
// current row is below `idx`. Callers are therefore expected to query rows in
// non-decreasing order.
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // The iterator function signals the end of the column with a negative row.
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  return (cur_idx_ == idx) ? cur_val_ : 0.0f;
}

// Returns the next (row, value) nonzero of this column and advances the
// cursor. Once exhausted, it marks the iterator as ended and returns
// (-1, 0.0) from then on.
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_++);
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}