c_api.cpp 65.2 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20
21

#include <string>
22
23
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24
#include <memory>
wxchan's avatar
wxchan committed
25
#include <mutex>
26
27
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
28

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Hand the message of a caught std::exception to LGBM_SetLastError.
 * \param ex Exception whose what() text is recorded
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* const message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}
/*!
 * \brief Hand a string error description to LGBM_SetLastError.
 * \param ex Error text to record
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* const message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// API_BEGIN() / API_END() bracket the body of a C API function in a try block.
// When the body finishes without throwing, control reaches the trailing
// `return 0;`. A caught std::exception or std::string is handed to
// LGBM_APIHandleException, any other exception is reported with the text
// "unknown exception", and the handler's result is returned.
#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
/*! \brief Number of slots in the single-row predictor cache kept by Booster */
constexpr int PREDICTOR_TYPES = 4;

/*!
 * \brief Single-row predictor bundled with the settings it was built from.
 *
 * The constructor records the early-stop settings, the iteration count and the
 * model count of the boosting object, so a holder can ask IsPredictorEqual()
 * whether a cached instance still matches a new request.
 */
class SingleRowPredictor {
 public:
  /*! \brief Prediction callback obtained from the wrapped Predictor */
  PredictFunction predict_function;
  /*! \brief Result of Boosting::NumPredictOneRow for the construction arguments */
  int64_t num_pred_in_one_row;

  /*!
   * \brief Build a Predictor for the requested prediction type.
   * \param predict_type One of the C_API_PREDICT_* constants; any value other
   *        than LEAF_INDEX, RAW_SCORE or CONTRIB leaves all three flags false
   * \param boosting Model used for prediction
   * \param config Source of the pred_early_stop* settings
   * \param iter Iteration count forwarded to Predictor and NumPredictOneRow
   */
  SingleRowPredictor(int predict_type, Boosting& boosting, const Config& config, int iter) {
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
    predictor_.reset(new Predictor(&boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
                                   early_stop_, early_stop_freq_, early_stop_margin_));
    num_pred_in_one_row = boosting.NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
    predict_function = predictor_->GetPredictFunction();
    num_total_model_ = boosting.NumberOfTotalModel();
  }
  ~SingleRowPredictor() {}

  /*!
   * \brief Check whether this predictor was built with the given settings.
   * \param config Config whose pred_early_stop* fields are compared
   * \param iter Requested iteration count
   * \param boosting Model whose current NumberOfTotalModel() is compared
   * \return true only when every recorded setting matches, i.e. the cached
   *         predictor can be reused; false when any of them differs
   */
  bool IsPredictorEqual(const Config& config, int iter, Boosting& boosting) const {
    // Every field has to match. The previous implementation returned true when
    // a field differed, which inverted the meaning callers rely on.
    return early_stop_ == config.pred_early_stop &&
      early_stop_freq_ == config.pred_early_stop_freq &&
      early_stop_margin_ == config.pred_early_stop_margin &&
      iter_ == iter &&
      num_total_model_ == boosting.NumberOfTotalModel();
  }

 private:
  /*! \brief Owns the Predictor that predict_function was taken from */
  std::unique_ptr<Predictor> predictor_;
  /*! \brief config.pred_early_stop at construction */
  bool early_stop_;
  /*! \brief config.pred_early_stop_freq at construction */
  int early_stop_freq_;
  /*! \brief config.pred_early_stop_margin at construction */
  double early_stop_margin_;
  /*! \brief Iteration count passed to the constructor */
  int iter_;
  /*! \brief boosting.NumberOfTotalModel() at construction */
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
98
class Booster {
Nikita Titov's avatar
Nikita Titov committed
99
 public:
Guolin Ke's avatar
Guolin Ke committed
100
  explicit Booster(const char* filename) {
101
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
102
103
  }

Guolin Ke's avatar
Guolin Ke committed
104
  Booster(const Dataset* train_data,
105
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
106
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
107
    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
112
    if (config_.input_model.size() > 0) {
113
114
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
115
    }
Guolin Ke's avatar
Guolin Ke committed
116

Guolin Ke's avatar
Guolin Ke committed
117
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
118

119
120
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
121
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
122
    if (config_.tree_learner == std::string("feature")) {
123
      Log::Fatal("Do not support feature parallel in c api");
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
126
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
127
      config_.tree_learner = "serial";
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
130
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
131
132
133
134
135
  }

  /*!
   * \brief Merge another booster's model into this one through Boosting::MergeFrom.
   * \param other Booster whose boosting object is merged in; only this
   *        booster's mutex is held during the call
   */
  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
  }

  /*! \brief Nothing to release by hand; the members clean up after themselves */
  ~Booster() = default;
140

141
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
142
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
143
144
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151
152
153
154
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
155
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
156
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
157
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
158
159
160
161
162
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
163
164
165
166
167
168
169
170
171
172
173
  }

  /*!
   * \brief Switch the booster to a different training dataset.
   * \param train_data New training data; nothing happens when it is the
   *        dataset already in use
   */
  void ResetTrainingData(const Dataset* train_data) {
    // Lock before reading train_data_ so the comparison cannot race with
    // another caller that is replacing it.
    std::lock_guard<std::mutex> lock(mutex_);
    if (train_data == train_data_) {
      return;
    }
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
    // reset the boosting
    boosting_->ResetTrainingData(train_data_,
                                 objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
177
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
178
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
179
    if (param.count("num_class")) {
180
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
181
    }
Guolin Ke's avatar
Guolin Ke committed
182
183
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
184
    }
Guolin Ke's avatar
Guolin Ke committed
185
    if (param.count("metric")) {
186
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
187
    }
Guolin Ke's avatar
Guolin Ke committed
188
189

    config_.Set(param);
190
191
192
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
193
194
195

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
196
197
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
198
199
200
201
202
203
204
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
205
206
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
207
    }
Guolin Ke's avatar
Guolin Ke committed
208

Guolin Ke's avatar
Guolin Ke committed
209
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
210
211
212
213
214
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
215
216
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
217
218
219
220
221
222
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
223
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
224
  }
Guolin Ke's avatar
Guolin Ke committed
225

226
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
227
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
228
    return boosting_->TrainOneIter(nullptr, nullptr);
229
230
  }

Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
235
236
237
238
239
240
241
  /*!
   * \brief Refit the model from a flat matrix of leaf predictions.
   * \param leaf_preds Flat array read as leaf_preds[i * ncol + j]
   * \param nrow Number of rows in leaf_preds
   * \param ncol Number of columns in leaf_preds
   */
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      // compute the flat offset in size_t: i * ncol can exceed the int range
      const size_t row_offset = static_cast<size_t>(i) * ncol;
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[row_offset + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

242
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
243
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
244
    return boosting_->TrainOneIter(gradients, hessians);
245
246
  }

wxchan's avatar
wxchan committed
247
248
249
250
251
  /*! \brief Call Boosting::RollbackOneIter while holding the booster mutex */
  void RollbackOneIter() {
    const std::lock_guard<std::mutex> guard(mutex_);
    boosting_->RollbackOneIter();
  }

252
253
254
255
256
  void PredictSingleRow(int num_iteration, int predict_type,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
    std::lock_guard<std::mutex> lock(mutex_);
257
258
259
260
    if (single_row_predictor_[predict_type].get() == nullptr ||
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, *boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, *boosting_.get(),
                                                                       config, num_iteration));
261
262
263
264
    }

    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
265
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
266

267
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
268
269
270
  }


Guolin Ke's avatar
Guolin Ke committed
271
272
  void Predict(int num_iteration, int predict_type, int nrow,
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
273
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
274
               double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
275
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
276
277
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
278
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
279
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
280
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
281
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
282
      is_raw_score = true;
283
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
284
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
285
286
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
287
    }
Guolin Ke's avatar
Guolin Ke committed
288

Guolin Ke's avatar
Guolin Ke committed
289
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
290
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
291
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
292
    auto pred_fun = predictor.GetPredictFunction();
293
294
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
295
    for (int i = 0; i < nrow; ++i) {
296
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
297
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
298
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
299
      pred_fun(one_row, pred_wrt_ptr);
300
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
301
    }
302
    OMP_THROW_EX();
303
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
304
305
306
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
307
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
308
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
309
310
311
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
312
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
313
314
315
316
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
317
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
318
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
319
320
321
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
322
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
323
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
324
325
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
326
327
  }

Guolin Ke's avatar
Guolin Ke committed
328
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
329
330
331
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

332
333
  /*!
   * \brief Forward to Boosting::SaveModelToFile.
   * \param start_iteration Forwarded unchanged
   * \param num_iteration Forwarded unchanged
   * \param filename Forwarded unchanged
   */
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
  }
335

336
  /*!
   * \brief Load a model from a NUL-terminated string.
   * \param model_str Model text; its length is measured with std::strlen and
   *        passed to Boosting::LoadModelFromString together with the pointer
   */
  void LoadModelFromString(const char* model_str) {
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
  }

341
342
  /*!
   * \brief Forward to Boosting::SaveModelToString.
   * \param start_iteration Forwarded unchanged
   * \param num_iteration Forwarded unchanged
   * \return The string produced by the boosting object
   */
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
  }

345
  std::string DumpModel(int start_iteration, int num_iteration) {
346
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
347
  }
348

349
350
351
352
  /*!
   * \brief Ask the boosting object for its feature importance values.
   * \param num_iteration Forwarded unchanged
   * \param importance_type Forwarded unchanged
   * \return The values returned by Boosting::FeatureImportance
   */
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    std::vector<double> importance =
      boosting_->FeatureImportance(num_iteration, importance_type);
    return importance;
  }

Guolin Ke's avatar
Guolin Ke committed
353
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
354
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
355
356
357
358
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
359
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
360
361
  }

362
  void ShuffleModels(int start_iter, int end_iter) {
363
    std::lock_guard<std::mutex> lock(mutex_);
364
    boosting_->ShuffleModels(start_iter, end_iter);
365
366
  }

wxchan's avatar
wxchan committed
367
368
369
370
371
372
373
  /*!
   * \brief Count the evaluation names across all training metrics.
   * \return Sum of GetName().size() over train_metric_
   */
  int GetEvalCounts() const {
    int total = 0;
    for (size_t i = 0; i < train_metric_.size(); ++i) {
      total += static_cast<int>(train_metric_[i]->GetName().size());
    }
    return total;
  }
374

wxchan's avatar
wxchan committed
375
376
377
378
  /*!
   * \brief Copy every training-metric name into caller-provided buffers.
   * \param out_strs Array of destination buffers, one per name; each receives
   *        name.size() + 1 bytes (the terminator included). The buffer sizes
   *        are not known here, so the caller has to make them large enough.
   * \return Number of names written
   */
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
386
387
388
  /*!
   * \brief Copy every feature name of the model into caller-provided buffers.
   * \param out_strs Array of destination buffers, one per name; each receives
   *        name.size() + 1 bytes (the terminator included). The buffer sizes
   *        are not known here, so the caller has to make them large enough.
   * \return Number of names written
   */
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
395
  /*! \brief Non-owning, read-only pointer to the wrapped boosting object */
  const Boosting* GetBoosting() const {
    return boosting_.get();
  }
Guolin Ke's avatar
Guolin Ke committed
396

Nikita Titov's avatar
Nikita Titov committed
397
 private:
wxchan's avatar
wxchan committed
398
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
399
  std::unique_ptr<Boosting> boosting_;
400
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
401

Guolin Ke's avatar
Guolin Ke committed
402
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
403
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
404
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
405
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
406
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
407
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
408
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
409
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
410
411
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
412
413
};

414
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
415
416
417

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
418
419
420
421
422
423
424
425
// Helper functions that wrap raw input buffers in per-row accessors.
// Only declared here; their definitions are not in this part of the file.

// Accessor that returns row `row_idx` of a dense matrix as a vector of doubles.
std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Accessor that returns row `row_idx` of a dense matrix as (int, double) pairs
// (presumably column index and value; confirm against the definition).
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

// Same pair-based accessor, for data passed as an array of row pointers.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

// Accessor that returns row `idx` of a CSR matrix as (int, double) pairs.
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
432
433
434

// Row iterator over one column of a CSC matrix.
// Only the interface is declared here; the member functions are defined
// elsewhere in the file.
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // return the value at idx; idx must be requested in ascending order
  double Get(int idx);
  // return the next non-zero pair; a negative index means no more data
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  // plain double literal: the member is a double, not a float
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
454
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
455
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
456
457
}

Guolin Ke's avatar
Guolin Ke committed
458
int LGBM_DatasetCreateFromFile(const char* filename,
459
460
461
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
462
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
463
464
  auto param = Config::Str2Map(parameters);
  Config config;
465
466
467
468
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
469
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
470
  if (reference == nullptr) {
471
472
473
474
475
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
476
  } else {
477
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
478
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
479
  }
480
  API_END();
Guolin Ke's avatar
Guolin Ke committed
481
482
}

483

Guolin Ke's avatar
Guolin Ke committed
484
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
485
486
487
488
489
490
491
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
492
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
493
494
  auto param = Config::Str2Map(parameters);
  Config config;
495
496
497
498
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
499
  DatasetLoader loader(config, nullptr, 1, nullptr);
500
501
502
503
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
504
505
}

506

Guolin Ke's avatar
Guolin Ke committed
507
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
508
509
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
514
515
516
517
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
518
int LGBM_DatasetPushRows(DatasetHandle dataset,
519
520
521
522
523
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
524
525
526
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
527
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
528
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
529
  for (int i = 0; i < nrow; ++i) {
530
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
531
532
533
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
534
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
535
  }
536
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
537
538
539
540
541
542
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
543
/*!
 * \brief Push a block of CSR rows into an existing dataset.
 * \param dataset Target dataset handle
 * \param indptr CSR row pointer array
 * \param indptr_type Element type code of indptr
 * \param indices CSR column indices
 * \param data CSR values
 * \param data_type Element type code of data
 * \param nindptr Length of indptr; the number of rows is nindptr - 1
 * \param nelem Number of stored elements
 * \param start_row Row of the dataset that receives the first pushed row
 * \return 0 on success, -1 when an exception was caught
 * \note The unnamed int64_t parameter is accepted but not used.
 */
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid,
                          static_cast<data_size_t>(start_row + i), one_row);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // finish loading once this block reaches the last row of the dataset
  if (start_row + nrow == static_cast<int64_t>(p_dataset->num_data())) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
574
int LGBM_DatasetCreateFromMat(const void* data,
575
576
577
578
579
580
581
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
603
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
604
605
  auto param = Config::Str2Map(parameters);
  Config config;
606
607
608
609
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
610
  std::unique_ptr<Dataset> ret;
611
612
613
614
615
616
617
618
619
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
620

Guolin Ke's avatar
Guolin Ke committed
621
622
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
623
    Random rand(config.data_random_seed);
624
625
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
626
    sample_cnt = static_cast<int>(sample_indices.size());
627
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
628
    std::vector<std::vector<int>> sample_idx(ncol);
629
630
631

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
632
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
633
      auto idx = sample_indices[i];
634
635
636
637
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
638

639
640
641
642
643
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
644
        }
Guolin Ke's avatar
Guolin Ke committed
645
646
      }
    }
Guolin Ke's avatar
Guolin Ke committed
647
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
648
649
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
650
651
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
652
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
653
  } else {
654
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
655
    ret->CreateValid(
656
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
657
  }
658
659
660
661
662
663
664
665
666
667
668
669
670
671
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
672
673
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
674
  *out = ret.release();
675
  API_END();
676
677
}

Guolin Ke's avatar
Guolin Ke committed
678
int LGBM_DatasetCreateFromCSR(const void* indptr,
679
680
681
682
683
684
685
686
687
688
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
689
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
690
691
  auto param = Config::Str2Map(parameters);
  Config config;
692
693
694
695
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
696
  std::unique_ptr<Dataset> ret;
697
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
698
699
700
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
701
702
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
703
    auto sample_indices = rand.Sample(nrow, sample_cnt);
704
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
705
706
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
707
708
709
710
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
711
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
712
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
713
714
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
715
716
717
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
718
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
719
720
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
721
722
723
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
724
  } else {
725
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
726
    ret->CreateValid(
727
      reinterpret_cast<const Dataset*>(reference));
728
  }
729
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
730
  #pragma omp parallel for schedule(static)
731
  for (int i = 0; i < nindptr - 1; ++i) {
732
    OMP_LOOP_EX_BEGIN();
733
734
735
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
736
    OMP_LOOP_EX_END();
737
  }
738
  OMP_THROW_EX();
739
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
740
  *out = ret.release();
741
  API_END();
742
743
}

744
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
745
746
747
748
749
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
  API_BEGIN();

  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);

  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
794

795
796
797
798
799
800
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
801
802
803
      const int tid = omp_get_thread_num();
      get_row_fun(i, threadBuffer);
      ret->PushOneRow(tid, i, threadBuffer);
804
805
806
807
808
809
810
811
812
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
813
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
814
815
816
817
818
819
820
821
822
823
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
824
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
825
826
  auto param = Config::Str2Map(parameters);
  Config config;
827
828
829
830
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
831
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
832
833
834
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
835
836
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
837
    auto sample_indices = rand.Sample(nrow, sample_cnt);
838
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
839
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
840
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
841
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
842
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
843
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
844
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
845
846
847
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
848
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
849
850
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
851
852
        }
      }
853
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
854
    }
855
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
856
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
857
858
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values).data(),
                                            Common::Vector2Ptr<int>(sample_idx).data(),
859
860
861
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
862
  } else {
863
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
864
    ret->CreateValid(
865
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
866
  }
867
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
868
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
869
  for (int i = 0; i < ncol_ptr - 1; ++i) {
870
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
871
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
872
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
873
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
874
875
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
876
877
878
879
880
881
882
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
883
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
884
    }
885
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
886
  }
887
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
888
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
889
  *out = ret.release();
890
  API_END();
Guolin Ke's avatar
Guolin Ke committed
891
892
}

Guolin Ke's avatar
Guolin Ke committed
893
int LGBM_DatasetGetSubset(
894
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
895
896
897
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
898
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
899
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
900
901
  auto param = Config::Str2Map(parameters);
  Config config;
902
903
904
905
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
906
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
907
  CHECK(num_used_row_indices > 0);
908
909
910
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
Guolin Ke's avatar
Guolin Ke committed
911
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
912
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
913
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
914
915
916
917
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
918
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
919
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
920
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
921
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
922
923
924
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
925
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
926
927
928
929
930
931
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
932
int LGBM_DatasetGetFeatureNames(
933
934
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
935
  int* num_feature_names) {
936
937
938
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
939
940
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
941
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
942
943
944
945
  }
  API_END();
}

946
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
947
/*!
 * \brief Destroy the Dataset behind a handle.
 * \param handle Dataset to delete
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
953
int LGBM_DatasetSaveBinary(DatasetHandle handle,
954
                           const char* filename) {
955
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
956
957
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
958
  API_END();
959
960
}

961
962
963
964
965
966
967
968
/*!
 * \brief Forward to Dataset::DumpTextFile for the given Dataset.
 * \param handle Dataset to dump
 * \param filename Passed unchanged to DumpTextFile
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
969
/*!
 * \brief Set a named field of a Dataset from a typed array.
 *
 * `type` selects how field_data is interpreted: C_API_DTYPE_FLOAT32,
 * C_API_DTYPE_INT32 or C_API_DTYPE_FLOAT64. Any other type, or a setter
 * returning false, raises an error.
 *
 * \param handle Dataset to update
 * \param field_name Name of the field
 * \param field_data Pointer to the values
 * \param num_element Number of values
 * \param type One of the C_API_DTYPE_* tags listed above
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  switch (type) {
    case C_API_DTYPE_FLOAT32:
      stored = target->SetFloatField(field_name, reinterpret_cast<const float*>(field_data), count);
      break;
    case C_API_DTYPE_INT32:
      stored = target->SetIntField(field_name, reinterpret_cast<const int*>(field_data), count);
      break;
    case C_API_DTYPE_FLOAT64:
      stored = target->SetDoubleField(field_name, reinterpret_cast<const double*>(field_data), count);
      break;
    default:
      // unsupported type tag: falls through to the error below
      break;
  }
  if (!stored) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
988
/*!
 * \brief Look up a named field of a Dataset and report its data and type.
 *
 * The float, int, double and int8 getters are tried in that order; the first
 * one that succeeds determines *out_type. If none succeeds an error is raised.
 * When the resulting pointer is null, *out_len is forced to 0.
 *
 * \param handle Dataset to query
 * \param field_name Name of the field
 * \param out_len Receives the number of elements
 * \param out_ptr Receives a pointer to the field data
 * \param out_type Receives the matching C_API_DTYPE_* tag
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (source->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

1014
1015
1016
1017
1018
1019
1020
/*!
 * \brief Forward a parameter string to Dataset::ResetConfig.
 * \param handle Dataset to update
 * \param parameters Passed unchanged to ResetConfig
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1021
int LGBM_DatasetGetNumData(DatasetHandle handle,
1022
                           int* out) {
1023
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1024
1025
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1026
  API_END();
1027
1028
}

Guolin Ke's avatar
Guolin Ke committed
1029
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1030
                              int* out) {
1031
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1032
1033
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1034
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1035
}
1036

1037
1038
1039
1040
1041
1042
1043
1044
1045
/*!
 * \brief Forward to Dataset::addFeaturesFrom, adding `source` into `target`.
 * \param target Dataset that receives the features
 * \param source Dataset passed to addFeaturesFrom
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* from = reinterpret_cast<Dataset*>(source);
  Dataset* into = reinterpret_cast<Dataset*>(target);
  into->addFeaturesFrom(from);
  API_END();
}

1046
1047
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1048
int LGBM_BoosterCreate(const DatasetHandle train_data,
1049
1050
                       const char* parameters,
                       BoosterHandle* out) {
1051
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1052
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1053
1054
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1055
  API_END();
1056
1057
}

Guolin Ke's avatar
Guolin Ke committed
1058
int LGBM_BoosterCreateFromModelfile(
1059
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1060
  int* out_num_iterations,
1061
  BoosterHandle* out) {
1062
  API_BEGIN();
wxchan's avatar
wxchan committed
1063
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1064
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1065
  *out = ret.release();
1066
  API_END();
1067
1068
}

Guolin Ke's avatar
Guolin Ke committed
1069
int LGBM_BoosterLoadModelFromString(
1070
1071
1072
1073
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1074
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1075
1076
1077
1078
1079
1080
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1081
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1082
/*!
 * \brief Destroy the Booster behind a handle.
 * \param handle Booster to delete
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1088
/*!
 * \brief Forward to Booster::ShuffleModels.
 * \param handle Booster to operate on
 * \param start_iter Passed unchanged to ShuffleModels
 * \param end_iter Passed unchanged to ShuffleModels
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1095
/*!
 * \brief Forward to Booster::MergeFrom, merging `other_handle` into `handle`.
 * \param handle Booster that MergeFrom is called on
 * \param other_handle Booster passed to MergeFrom
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* other = reinterpret_cast<Booster*>(other_handle);
  Booster* self = reinterpret_cast<Booster*>(handle);
  self->MergeFrom(other);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1104
int LGBM_BoosterAddValidData(BoosterHandle handle,
1105
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1106
1107
1108
1109
1110
1111
1112
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1113
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1114
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1115
1116
1117
1118
1119
1120
1121
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1122
/*!
 * \brief Forward a parameter string to Booster::ResetConfig.
 * \param handle Booster to operate on
 * \param parameters Passed unchanged to ResetConfig
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1129
/*!
 * \brief Report NumberOfClasses() of the Booster's boosting object.
 * \param handle Booster to query
 * \param out_len Receives the value of NumberOfClasses()
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1136
1137
1138
1139
1140
1141
1142
/*!
 * \brief Forward to Booster::Refit.
 * \param handle Booster to operate on
 * \param leaf_preds Passed unchanged to Refit
 * \param nrow Passed unchanged to Refit
 * \param ncol Passed unchanged to Refit
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1143
/*!
 * \brief Run Booster::TrainOneIter() once.
 * \param handle Booster to train
 * \param is_finished Set to 1 when TrainOneIter() returns true, otherwise 0
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  *is_finished = booster->TrainOneIter() ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1154
/*!
 * \brief Run Booster::TrainOneIter(grad, hess) once with caller-supplied
 *        gradients and hessians.
 *
 * When SCORE_T_USE_DOUBLE is defined this path is compiled out and
 * Log::Fatal is called instead.
 *
 * \param handle Booster to train
 * \param grad Passed unchanged to TrainOneIter
 * \param hess Passed unchanged to TrainOneIter
 * \param is_finished Set to 1 when TrainOneIter returns true, otherwise 0
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  *is_finished = booster->TrainOneIter(grad, hess) ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1172
/*!
 * \brief Forward to Booster::RollbackOneIter.
 * \param handle Booster to operate on
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1179
/*!
 * \brief Report GetCurrentIteration() of the Booster's boosting object.
 * \param handle Booster to query
 * \param out_iteration Receives the value of GetCurrentIteration()
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  *out_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1185

1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
/*!
 * \brief Report NumModelPerIteration() of the Booster's boosting object.
 * \param handle Booster to query
 * \param out_tree_per_iteration Receives the value of NumModelPerIteration()
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  *out_tree_per_iteration = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumModelPerIteration();
  API_END();
}

/*!
 * \brief Report NumberOfTotalModel() of the Booster's boosting object.
 * \param handle Booster to query
 * \param out_models Receives the value of NumberOfTotalModel()
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  *out_models = reinterpret_cast<Booster*>(handle)->GetBoosting()->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1200
/*!
 * \brief Report Booster::GetEvalCounts().
 * \param handle Booster to query
 * \param out_len Receives the value of GetEvalCounts()
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalCounts();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1207
/*!
 * \brief Forward to Booster::GetEvalNames.
 * \param handle Booster to query
 * \param out_len Receives the value returned by GetEvalNames
 * \param out_strs Buffers passed unchanged to GetEvalNames
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetEvalNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1214
/*!
 * \brief Forward to Booster::GetFeatureNames.
 * \param handle Booster to query
 * \param out_len Receives the value returned by GetFeatureNames
 * \param out_strs Buffers passed unchanged to GetFeatureNames
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetFeatureNames(out_strs);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1221
/*!
 * \brief Report the feature count as MaxFeatureIdx() + 1 of the boosting object.
 * \param handle Booster to query
 * \param out_len Receives MaxFeatureIdx() + 1
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  *out_len = reinterpret_cast<Booster*>(handle)->GetBoosting()->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1228
/*!
 * \brief Copy the evaluation results for one data index into a caller buffer.
 * \param handle Booster to query
 * \param data_idx Passed unchanged to GetEvalAt
 * \param out_len Receives the number of results
 * \param out_results Receives the results, each converted to double
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto metrics = reinterpret_cast<Booster*>(handle)->GetBoosting()->GetEvalAt(data_idx);
  *out_len = static_cast<int>(metrics.size());
  size_t pos = 0;
  for (const auto& value : metrics) {
    out_results[pos] = static_cast<double>(value);
    ++pos;
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1243
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1244
1245
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1246
1247
1248
1249
1250
1251
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1252
int LGBM_BoosterGetPredict(BoosterHandle handle,
1253
1254
1255
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1256
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1257
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1258
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1259
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1260
1261
}

Guolin Ke's avatar
Guolin Ke committed
1262
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1263
1264
1265
1266
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1267
                               const char* parameter,
1268
                               const char* result_filename) {
1269
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1270
1271
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1272
1273
1274
1275
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1276
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1277
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1278
                       config, result_filename);
1279
  API_END();
1280
1281
}

Guolin Ke's avatar
Guolin Ke committed
1282
/*!
 * \brief Compute the number of prediction values for a number of rows.
 *
 * The result is num_row (widened to int64_t) multiplied by
 * NumPredictOneRow(num_iteration, is-leaf-index, is-contrib).
 *
 * \param handle Booster to query
 * \param num_row Number of rows
 * \param predict_type Compared against C_API_PREDICT_LEAF_INDEX and
 *        C_API_PREDICT_CONTRIB
 * \param num_iteration Passed unchanged to NumPredictOneRow
 * \param out_len Receives the computed count
 * \return 0 on success, -1 if an exception was caught (see API_END)
 */
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool wants_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool wants_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  *out_len = static_cast<int64_t>(num_row) *
    booster->GetBoosting()->NumPredictOneRow(num_iteration, wants_leaf_index, wants_contrib);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1294
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int predict_type,
                              int num_iteration,
1305
                              const char* parameter,
1306
1307
                              int64_t* out_len,
                              double* out_result) {
1308
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1309
1310
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1311
1312
1313
1314
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1315
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1316
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1317
  int nrow = static_cast<int>(nindptr - 1);
cbecker's avatar
cbecker committed
1318
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1319
                       config, out_result, out_len);
1320
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1321
}
1322

1323
/*!
 * \brief Single-row prediction for data passed in CSR layout.
 *
 * Same setup as the batch CSR entry point (parameter parsing, optional OpenMP
 * thread override, CSR row accessor) but forwards to Booster::PredictSingleRow.
 * The unnamed int64_t argument is accepted but not read by this function.
 *
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
                                       int64_t,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader, config, out_result, out_len);
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1351
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1362
                              const char* parameter,
1363
1364
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1365
1366
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1367
1368
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1379
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1380
1381
1382
1383
1384
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1385
1386
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1387
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1388
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1389
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1390
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1391
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1392
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1393
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1394
1395
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1396
1397
    return one_row;
  };
Guolin Ke's avatar
Guolin Ke committed
1398
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config,
cbecker's avatar
cbecker committed
1399
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1400
1401
1402
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1403
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1404
1405
1406
1407
1408
1409
1410
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1411
                              const char* parameter,
1412
1413
                              int64_t* out_len,
                              double* out_result) {
1414
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1415
1416
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1417
1418
1419
1420
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1421
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1422
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
cbecker's avatar
cbecker committed
1423
  ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1424
                       config, out_result, out_len);
1425
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1426
}
1427

1428
/*!
 * \brief Single-row prediction for one dense row.
 *
 * The buffer is treated as a matrix with exactly one row (num_row is fixed to 1
 * when building the accessor) and passed to Booster::PredictSingleRow.
 *
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
  booster->PredictSingleRow(num_iteration, predict_type, row_reader, config, out_result, out_len);
  API_END();
}


1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
/*!
 * \brief Batch prediction for dense data supplied as an array of row pointers.
 *
 * Each data[i] is read as one row of ncol values via
 * RowPairFunctionFromDenseRows; the rows are then passed to Booster::Predict.
 *
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto parsed_params = Config::Str2Map(parameter);
  Config config;
  config.Set(parsed_params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  booster->Predict(num_iteration, predict_type, nrow, row_reader, config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1475
int LGBM_BoosterSaveModel(BoosterHandle handle,
1476
                          int start_iteration,
1477
1478
                          int num_iteration,
                          const char* filename) {
1479
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1480
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1481
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1482
1483
1484
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1485
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1486
                                  int start_iteration,
1487
                                  int num_iteration,
1488
                                  int64_t buffer_len,
1489
                                  int64_t* out_len,
1490
                                  char* out_str) {
1491
1492
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1493
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1494
  *out_len = static_cast<int64_t>(model.size()) + 1;
1495
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1496
    std::memcpy(out_str, model.c_str(), *out_len);
1497
1498
1499
1500
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1501
int LGBM_BoosterDumpModel(BoosterHandle handle,
1502
                          int start_iteration,
1503
                          int num_iteration,
1504
1505
                          int64_t buffer_len,
                          int64_t* out_len,
1506
                          char* out_str) {
wxchan's avatar
wxchan committed
1507
1508
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1509
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1510
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1511
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1512
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1513
  }
1514
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1515
}
1516

Guolin Ke's avatar
Guolin Ke committed
1517
/*!
 * \brief Read one leaf value through Booster::GetLeafValue and store it in *out_val as a double.
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const double leaf_value = static_cast<double>(booster->GetLeafValue(tree_idx, leaf_idx));
  *out_val = leaf_value;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1527
/*!
 * \brief Forward a leaf-value update to Booster::SetLeafValue.
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
/*!
 * \brief Copy the values returned by Booster::FeatureImportance into out_results.
 *
 * out_results is written sequentially, one slot per returned value; the caller
 * must provide a buffer large enough for all of them.
 *
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double importance : importances) {
    *dst = importance;
    ++dst;
  }
  API_END();
}

1550
1551
1552
1553
1554
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1555
  Config config;
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

/*!
 * \brief Release the network layer by calling Network::Dispose.
 * \return 0 on success, -1 if an exception was caught by API_END
 */
int LGBM_NetworkFree() {
  API_BEGIN();
  Network::Dispose();
  API_END();
}

1572
1573
1574
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1575
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1576
  if (num_machines > 1) {
1577
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1578
1579
1580
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1581

Guolin Ke's avatar
Guolin Ke committed
1582
// ---- start of some helper functions
1583
1584
1585

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1586
  if (data_type == C_API_DTYPE_FLOAT32) {
1587
1588
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1589
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1590
        std::vector<double> ret(num_col);
1591
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1592
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1593
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1594
1595
1596
1597
        }
        return ret;
      };
    } else {
1598
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1599
        std::vector<double> ret(num_col);
1600
        for (int i = 0; i < num_col; ++i) {
1601
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1602
1603
1604
1605
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1606
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1607
1608
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1609
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1610
        std::vector<double> ret(num_col);
1611
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1612
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1613
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1614
1615
1616
1617
        }
        return ret;
      };
    } else {
1618
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1619
        std::vector<double> ret(num_col);
1620
        for (int i = 0; i < num_col; ++i) {
1621
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1622
1623
1624
1625
1626
        }
        return ret;
      };
    }
  }
1627
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1628
1629
1630
1631
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1632
1633
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1634
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1635
1636
1637
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1638
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1639
          ret.emplace_back(i, raw_values[i]);
1640
        }
Guolin Ke's avatar
Guolin Ke committed
1641
1642
1643
      }
      return ret;
    };
1644
  }
Guolin Ke's avatar
Guolin Ke committed
1645
  return nullptr;
1646
1647
}

1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
// `data` is an array of pointers, each addressing one row of `num_col` values.
// The returned accessor yields row `row_idx` as (column index, value) pairs, keeping
// only entries whose magnitude exceeds kZeroThreshold, plus NaN entries.
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    // View the selected row as a 1 x num_col row-major matrix and read its only row.
    auto single_row_reader = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    const auto dense_row = single_row_reader(0);
    std::vector<std::pair<int, double>> sparse_row;
    int col = 0;
    for (const double value : dense_row) {
      if (std::fabs(value) > kZeroThreshold || std::isnan(value)) {
        sparse_row.emplace_back(col, value);
      }
      ++col;
    }
    return sparse_row;
  };
}

1664
std::function<std::vector<std::pair<int, double>>(int idx)>
1665
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1666
  if (data_type == C_API_DTYPE_FLOAT32) {
1667
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1668
    if (indptr_type == C_API_DTYPE_INT32) {
1669
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1670
      return [=] (int idx) {
1671
1672
1673
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1674
1675
1676
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1677
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1678
          ret.emplace_back(indices[i], data_ptr[i]);
1679
1680
1681
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1682
    } else if (indptr_type == C_API_DTYPE_INT64) {
1683
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1684
      return [=] (int idx) {
1685
1686
1687
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1688
1689
1690
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1691
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1692
          ret.emplace_back(indices[i], data_ptr[i]);
1693
1694
1695
1696
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1697
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1698
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1699
    if (indptr_type == C_API_DTYPE_INT32) {
1700
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1701
      return [=] (int idx) {
1702
1703
1704
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1705
1706
1707
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1708
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1709
          ret.emplace_back(indices[i], data_ptr[i]);
1710
1711
1712
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1713
    } else if (indptr_type == C_API_DTYPE_INT64) {
1714
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1715
      return [=] (int idx) {
1716
1717
1718
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1719
1720
1721
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1722
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1723
          ret.emplace_back(indices[i], data_ptr[i]);
1724
1725
1726
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1727
1728
    }
  }
1729
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1730
1731
}

Guolin Ke's avatar
Guolin Ke committed
1732
std::function<std::pair<int, double>(int idx)>
1733
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1734
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1735
  if (data_type == C_API_DTYPE_FLOAT32) {
1736
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1737
    if (col_ptr_type == C_API_DTYPE_INT32) {
1738
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1739
1740
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1741
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1742
1743
1744
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1745
        }
Guolin Ke's avatar
Guolin Ke committed
1746
1747
1748
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1749
      };
Guolin Ke's avatar
Guolin Ke committed
1750
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1751
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1752
1753
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1754
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1755
1756
1757
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1758
        }
Guolin Ke's avatar
Guolin Ke committed
1759
1760
1761
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1762
      };
Guolin Ke's avatar
Guolin Ke committed
1763
    }
Guolin Ke's avatar
Guolin Ke committed
1764
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1765
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1766
    if (col_ptr_type == C_API_DTYPE_INT32) {
1767
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1768
1769
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1770
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1771
1772
1773
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1774
        }
Guolin Ke's avatar
Guolin Ke committed
1775
1776
1777
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1778
      };
Guolin Ke's avatar
Guolin Ke committed
1779
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1780
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1781
1782
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1783
      return [=] (int bias) {
Guolin Ke's avatar
Guolin Ke committed
1784
1785
1786
        int64_t i = static_cast<int64_t>(start + bias);
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1787
        }
Guolin Ke's avatar
Guolin Ke committed
1788
1789
1790
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1791
      };
Guolin Ke's avatar
Guolin Ke committed
1792
1793
    }
  }
1794
  throw std::runtime_error("Unknown data type in CSC matrix");
1795
1796
}

Guolin Ke's avatar
Guolin Ke committed
1797
/*!
 * \brief Bind the iterator to column col_idx of a CSC matrix.
 *
 * All arguments are forwarded unchanged to IterateFunctionFromCSC, whose result
 * becomes the iterator's element accessor.
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
    : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, col_idx)) {
}

/*!
 * \brief Return the column's value at row idx, or zero if no entry is stored there.
 *
 * The cursor only moves forward: stored entries are consumed until the current
 * row index reaches or passes idx, or the column is exhausted. Asking for a row
 * before the current cursor position therefore yields zero.
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    const auto entry = iter_fun_(nonzero_idx_);
    // A negative row index is the accessor's end-of-column marker.
    if (entry.first < 0) {
      is_end_ = true;
      break;
    }
    cur_idx_ = entry.first;
    cur_val_ = entry.second;
    ++nonzero_idx_;
  }
  if (cur_idx_ != idx) {
    return 0.0f;
  }
  return cur_val_;
}

/*!
 * \brief Return the next stored (row index, value) entry of the column.
 *
 * Once the accessor reports a negative row index the iterator is marked as
 * finished; that call and every later call return an entry with row index -1
 * (later calls return exactly (-1, 0.0)).
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  if (entry.first < 0) {
    is_end_ = true;
  }
  return entry;
}