c_api.cpp 66.7 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
#include <LightGBM/c_api.h>
Guolin Ke's avatar
Guolin Ke committed
6

Guolin Ke's avatar
Guolin Ke committed
7
8
#include <LightGBM/boosting.h>
#include <LightGBM/config.h>
9
10
11
#include <LightGBM/dataset.h>
#include <LightGBM/dataset_loader.h>
#include <LightGBM/metric.h>
12
#include <LightGBM/network.h>
13
14
15
16
17
18
19
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/random.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
20
21

#include <string>
22
23
#include <cstdio>
#include <functional>
Guolin Ke's avatar
Guolin Ke committed
24
#include <memory>
wxchan's avatar
wxchan committed
25
#include <mutex>
26
27
#include <stdexcept>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
28

29
#include "application/predictor.hpp"
Guolin Ke's avatar
Guolin Ke committed
30

Guolin Ke's avatar
Guolin Ke committed
31
32
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
/*!
 * \brief Record the message of a caught std::exception as the last error.
 * \param ex Exception whose what() text is stored via LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::exception& ex) {
  const char* message = ex.what();
  LGBM_SetLastError(message);
  return -1;
}

/*!
 * \brief Record a caught string as the last error.
 * \param ex Message text stored via LGBM_SetLastError
 * \return Always -1
 */
inline int LGBM_APIHandleException(const std::string& ex) {
  const char* message = ex.c_str();
  LGBM_SetLastError(message);
  return -1;
}

// API_BEGIN()/API_END() bracket the body of a C API entry point:
// the body runs inside a try block; a caught std::exception, a caught
// std::string, or any other exception is forwarded to
// LGBM_APIHandleException and its result is returned. If nothing is
// thrown, the enclosing function returns 0.
#define API_BEGIN() try {
#define API_END() } \
catch (const std::exception& err) { return LGBM_APIHandleException(err); } \
catch (const std::string& err) { return LGBM_APIHandleException(err); } \
catch (...) { return LGBM_APIHandleException("unknown exception"); } \
return 0;

49
50
51
52
53
54
55
56
// Number of slots in Booster's per-predict-type single row predictor cache.
constexpr int PREDICTOR_TYPES = 4;

// Single row predictor to abstract away caching logic
class SingleRowPredictor {
 public:
  PredictFunction predict_function;
  int64_t num_pred_in_one_row;

Guolin Ke's avatar
Guolin Ke committed
57
  SingleRowPredictor(int predict_type, Boosting* boosting, const Config& config, int iter) {
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    bool is_predict_leaf = false;
    bool is_raw_score = false;
    bool predict_contrib = false;
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
      predict_contrib = true;
    } else {
      is_raw_score = false;
    }
    early_stop_ = config.pred_early_stop;
    early_stop_freq_ = config.pred_early_stop_freq;
    early_stop_margin_ = config.pred_early_stop_margin;
    iter_ = iter;
Guolin Ke's avatar
Guolin Ke committed
74
    predictor_.reset(new Predictor(boosting, iter_, is_raw_score, is_predict_leaf, predict_contrib,
75
                                   early_stop_, early_stop_freq_, early_stop_margin_));
Guolin Ke's avatar
Guolin Ke committed
76
    num_pred_in_one_row = boosting->NumPredictOneRow(iter_, is_predict_leaf, predict_contrib);
77
    predict_function = predictor_->GetPredictFunction();
Guolin Ke's avatar
Guolin Ke committed
78
    num_total_model_ = boosting->NumberOfTotalModel();
79
80
  }
  ~SingleRowPredictor() {}
Guolin Ke's avatar
Guolin Ke committed
81
  bool IsPredictorEqual(const Config& config, int iter, Boosting* boosting) {
82
83
84
85
    return early_stop_ != config.pred_early_stop ||
      early_stop_freq_ != config.pred_early_stop_freq ||
      early_stop_margin_ != config.pred_early_stop_margin ||
      iter_ != iter ||
Guolin Ke's avatar
Guolin Ke committed
86
      num_total_model_ != boosting->NumberOfTotalModel();
87
  }
Guolin Ke's avatar
Guolin Ke committed
88

89
90
91
92
93
94
95
96
97
 private:
  std::unique_ptr<Predictor> predictor_;
  bool early_stop_;
  int early_stop_freq_;
  double early_stop_margin_;
  int iter_;
  int num_total_model_;
};

Guolin Ke's avatar
Guolin Ke committed
98
class Booster {
Nikita Titov's avatar
Nikita Titov committed
99
 public:
Guolin Ke's avatar
Guolin Ke committed
100
  explicit Booster(const char* filename) {
101
    boosting_.reset(Boosting::CreateBoosting("gbdt", filename));
102
103
  }

Guolin Ke's avatar
Guolin Ke committed
104
  Booster(const Dataset* train_data,
105
          const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
106
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
107
    config_.Set(param);
108
109
110
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
111
    // create boosting
Guolin Ke's avatar
Guolin Ke committed
112
    if (config_.input_model.size() > 0) {
113
114
      Log::Warning("Continued train from model is not supported for c_api,\n"
                   "please use continued train with input score");
Guolin Ke's avatar
Guolin Ke committed
115
    }
Guolin Ke's avatar
Guolin Ke committed
116

Guolin Ke's avatar
Guolin Ke committed
117
    boosting_.reset(Boosting::CreateBoosting(config_.boosting, nullptr));
Guolin Ke's avatar
Guolin Ke committed
118

119
120
    train_data_ = train_data;
    CreateObjectiveAndMetrics();
Guolin Ke's avatar
Guolin Ke committed
121
    // initialize the boosting
Guolin Ke's avatar
Guolin Ke committed
122
    if (config_.tree_learner == std::string("feature")) {
123
      Log::Fatal("Do not support feature parallel in c api");
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    if (Network::num_machines() == 1 && config_.tree_learner != std::string("serial")) {
126
      Log::Warning("Only find one worker, will switch to serial tree learner");
Guolin Ke's avatar
Guolin Ke committed
127
      config_.tree_learner = "serial";
128
    }
Guolin Ke's avatar
Guolin Ke committed
129
    boosting_->Init(&config_, train_data_, objective_fun_.get(),
130
                    Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
131
132
133
134
135
  }

  void MergeFrom(const Booster* other) {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->MergeFrom(other->boosting_.get());
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
  }

  ~Booster() {
  }
140

141
  void CreateObjectiveAndMetrics() {
Guolin Ke's avatar
Guolin Ke committed
142
    // create objective function
Guolin Ke's avatar
Guolin Ke committed
143
144
    objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                    config_));
Guolin Ke's avatar
Guolin Ke committed
145
146
147
148
149
150
151
152
153
154
    if (objective_fun_ == nullptr) {
      Log::Warning("Using self-defined objective function");
    }
    // initialize the objective function
    if (objective_fun_ != nullptr) {
      objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
    }

    // create training metric
    train_metric_.clear();
Guolin Ke's avatar
Guolin Ke committed
155
    for (auto metric_type : config_.metric) {
Guolin Ke's avatar
Guolin Ke committed
156
      auto metric = std::unique_ptr<Metric>(
Guolin Ke's avatar
Guolin Ke committed
157
        Metric::CreateMetric(metric_type, config_));
Guolin Ke's avatar
Guolin Ke committed
158
159
160
161
162
      if (metric == nullptr) { continue; }
      metric->Init(train_data_->metadata(), train_data_->num_data());
      train_metric_.push_back(std::move(metric));
    }
    train_metric_.shrink_to_fit();
163
164
165
166
167
168
169
170
171
172
173
  }

  void ResetTrainingData(const Dataset* train_data) {
    if (train_data != train_data_) {
      std::lock_guard<std::mutex> lock(mutex_);
      train_data_ = train_data;
      CreateObjectiveAndMetrics();
      // reset the boosting
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
    }
wxchan's avatar
wxchan committed
174
175
176
  }

  void ResetConfig(const char* parameters) {
Guolin Ke's avatar
Guolin Ke committed
177
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
178
    auto param = Config::Str2Map(parameters);
wxchan's avatar
wxchan committed
179
    if (param.count("num_class")) {
180
      Log::Fatal("Cannot change num_class during training");
wxchan's avatar
wxchan committed
181
    }
Guolin Ke's avatar
Guolin Ke committed
182
183
    if (param.count("boosting")) {
      Log::Fatal("Cannot change boosting during training");
wxchan's avatar
wxchan committed
184
    }
Guolin Ke's avatar
Guolin Ke committed
185
    if (param.count("metric")) {
186
      Log::Fatal("Cannot change metric during training");
Guolin Ke's avatar
Guolin Ke committed
187
    }
Guolin Ke's avatar
Guolin Ke committed
188
189

    config_.Set(param);
190
191
192
    if (config_.num_threads > 0) {
      omp_set_num_threads(config_.num_threads);
    }
Guolin Ke's avatar
Guolin Ke committed
193
194
195

    if (param.count("objective")) {
      // create objective function
Guolin Ke's avatar
Guolin Ke committed
196
197
      objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective,
                                                                      config_));
Guolin Ke's avatar
Guolin Ke committed
198
199
200
201
202
203
204
      if (objective_fun_ == nullptr) {
        Log::Warning("Using self-defined objective function");
      }
      // initialize the objective function
      if (objective_fun_ != nullptr) {
        objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
      }
205
206
      boosting_->ResetTrainingData(train_data_,
                                   objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
wxchan's avatar
wxchan committed
207
    }
Guolin Ke's avatar
Guolin Ke committed
208

Guolin Ke's avatar
Guolin Ke committed
209
    boosting_->ResetConfig(&config_);
wxchan's avatar
wxchan committed
210
211
212
213
214
  }

  void AddValidData(const Dataset* valid_data) {
    std::lock_guard<std::mutex> lock(mutex_);
    valid_metrics_.emplace_back();
Guolin Ke's avatar
Guolin Ke committed
215
216
    for (auto metric_type : config_.metric) {
      auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_));
wxchan's avatar
wxchan committed
217
218
219
220
221
222
      if (metric == nullptr) { continue; }
      metric->Init(valid_data->metadata(), valid_data->num_data());
      valid_metrics_.back().push_back(std::move(metric));
    }
    valid_metrics_.back().shrink_to_fit();
    boosting_->AddValidDataset(valid_data,
223
                               Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
wxchan's avatar
wxchan committed
224
  }
Guolin Ke's avatar
Guolin Ke committed
225

226
  bool TrainOneIter() {
wxchan's avatar
wxchan committed
227
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
228
    return boosting_->TrainOneIter(nullptr, nullptr);
229
230
  }

Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
235
236
237
238
239
240
241
  void Refit(const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<std::vector<int32_t>> v_leaf_preds(nrow, std::vector<int32_t>(ncol, 0));
    for (int i = 0; i < nrow; ++i) {
      for (int j = 0; j < ncol; ++j) {
        v_leaf_preds[i][j] = leaf_preds[i * ncol + j];
      }
    }
    boosting_->RefitTree(v_leaf_preds);
  }

242
  bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
wxchan's avatar
wxchan committed
243
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
244
    return boosting_->TrainOneIter(gradients, hessians);
245
246
  }

wxchan's avatar
wxchan committed
247
248
249
250
251
  void RollbackOneIter() {
    std::lock_guard<std::mutex> lock(mutex_);
    boosting_->RollbackOneIter();
  }

252
  void PredictSingleRow(int num_iteration, int predict_type, int ncol,
253
254
255
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
               const Config& config,
               double* out_result, int64_t* out_len) {
256
257
258
    if (ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
    }
259
    std::lock_guard<std::mutex> lock(mutex_);
260
    if (single_row_predictor_[predict_type].get() == nullptr ||
Guolin Ke's avatar
Guolin Ke committed
261
262
        !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
      single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
263
                                                                       config, num_iteration));
264
265
266
    }
    auto one_row = get_row_fun(0);
    auto pred_wrt_ptr = out_result;
267
    single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
268

269
    *out_len = single_row_predictor_[predict_type]->num_pred_in_one_row;
270
271
272
  }


273
  void Predict(int num_iteration, int predict_type, int nrow, int ncol,
Guolin Ke's avatar
Guolin Ke committed
274
               std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
275
               const Config& config,
Guolin Ke's avatar
Guolin Ke committed
276
               double* out_result, int64_t* out_len) {
277
278
279
    if (ncol != boosting_->MaxFeatureIdx() + 1) {
      Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
    }
wxchan's avatar
wxchan committed
280
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
281
282
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
283
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
284
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
Guolin Ke's avatar
Guolin Ke committed
285
      is_predict_leaf = true;
Guolin Ke's avatar
Guolin Ke committed
286
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
Guolin Ke's avatar
Guolin Ke committed
287
      is_raw_score = true;
288
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
289
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
290
291
    } else {
      is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
292
    }
Guolin Ke's avatar
Guolin Ke committed
293

Guolin Ke's avatar
Guolin Ke committed
294
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
295
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
296
    int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, predict_contrib);
Guolin Ke's avatar
Guolin Ke committed
297
    auto pred_fun = predictor.GetPredictFunction();
298
299
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
300
    for (int i = 0; i < nrow; ++i) {
301
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
302
      auto one_row = get_row_fun(i);
Tony-Y's avatar
Tony-Y committed
303
      auto pred_wrt_ptr = out_result + static_cast<size_t>(num_pred_in_one_row) * i;
Guolin Ke's avatar
Guolin Ke committed
304
      pred_fun(one_row, pred_wrt_ptr);
305
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
306
    }
307
    OMP_THROW_EX();
308
    *out_len = num_pred_in_one_row * nrow;
Guolin Ke's avatar
Guolin Ke committed
309
310
311
  }

  void Predict(int num_iteration, int predict_type, const char* data_filename,
Guolin Ke's avatar
Guolin Ke committed
312
               int data_has_header, const Config& config,
cbecker's avatar
cbecker committed
313
               const char* result_filename) {
Guolin Ke's avatar
Guolin Ke committed
314
315
316
    std::lock_guard<std::mutex> lock(mutex_);
    bool is_predict_leaf = false;
    bool is_raw_score = false;
Guolin Ke's avatar
Guolin Ke committed
317
    bool predict_contrib = false;
Guolin Ke's avatar
Guolin Ke committed
318
319
320
321
    if (predict_type == C_API_PREDICT_LEAF_INDEX) {
      is_predict_leaf = true;
    } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
      is_raw_score = true;
322
    } else if (predict_type == C_API_PREDICT_CONTRIB) {
Guolin Ke's avatar
Guolin Ke committed
323
      predict_contrib = true;
Guolin Ke's avatar
Guolin Ke committed
324
325
326
    } else {
      is_raw_score = false;
    }
Guolin Ke's avatar
Guolin Ke committed
327
    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
328
                        config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
Guolin Ke's avatar
Guolin Ke committed
329
330
    bool bool_data_has_header = data_has_header > 0 ? true : false;
    predictor.Predict(data_filename, result_filename, bool_data_has_header);
Guolin Ke's avatar
Guolin Ke committed
331
332
  }

Guolin Ke's avatar
Guolin Ke committed
333
  void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
wxchan's avatar
wxchan committed
334
335
336
    boosting_->GetPredictAt(data_idx, out_result, out_len);
  }

337
338
  void SaveModelToFile(int start_iteration, int num_iteration, const char* filename) {
    boosting_->SaveModelToFile(start_iteration, num_iteration, filename);
Guolin Ke's avatar
Guolin Ke committed
339
  }
340

341
  void LoadModelFromString(const char* model_str) {
342
343
    size_t len = std::strlen(model_str);
    boosting_->LoadModelFromString(model_str, len);
344
345
  }

346
347
  std::string SaveModelToString(int start_iteration, int num_iteration) {
    return boosting_->SaveModelToString(start_iteration, num_iteration);
348
349
  }

350
  std::string DumpModel(int start_iteration, int num_iteration) {
351
    return boosting_->DumpModel(start_iteration, num_iteration);
wxchan's avatar
wxchan committed
352
  }
353

354
355
356
357
  std::vector<double> FeatureImportance(int num_iteration, int importance_type) {
    return boosting_->FeatureImportance(num_iteration, importance_type);
  }

Guolin Ke's avatar
Guolin Ke committed
358
  double GetLeafValue(int tree_idx, int leaf_idx) const {
Guolin Ke's avatar
Guolin Ke committed
359
    return dynamic_cast<GBDTBase*>(boosting_.get())->GetLeafValue(tree_idx, leaf_idx);
Guolin Ke's avatar
Guolin Ke committed
360
361
362
363
  }

  void SetLeafValue(int tree_idx, int leaf_idx, double val) {
    std::lock_guard<std::mutex> lock(mutex_);
Guolin Ke's avatar
Guolin Ke committed
364
    dynamic_cast<GBDTBase*>(boosting_.get())->SetLeafValue(tree_idx, leaf_idx, val);
Guolin Ke's avatar
Guolin Ke committed
365
366
  }

367
  void ShuffleModels(int start_iter, int end_iter) {
368
    std::lock_guard<std::mutex> lock(mutex_);
369
    boosting_->ShuffleModels(start_iter, end_iter);
370
371
  }

wxchan's avatar
wxchan committed
372
373
374
375
376
377
378
  int GetEvalCounts() const {
    int ret = 0;
    for (const auto& metric : train_metric_) {
      ret += static_cast<int>(metric->GetName().size());
    }
    return ret;
  }
379

wxchan's avatar
wxchan committed
380
381
382
383
  int GetEvalNames(char** out_strs) const {
    int idx = 0;
    for (const auto& metric : train_metric_) {
      for (const auto& name : metric->GetName()) {
Guolin Ke's avatar
Guolin Ke committed
384
        std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
385
386
387
388
389
390
        ++idx;
      }
    }
    return idx;
  }

wxchan's avatar
wxchan committed
391
392
393
  int GetFeatureNames(char** out_strs) const {
    int idx = 0;
    for (const auto& name : boosting_->FeatureNames()) {
Guolin Ke's avatar
Guolin Ke committed
394
      std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
wxchan's avatar
wxchan committed
395
396
397
398
399
      ++idx;
    }
    return idx;
  }

wxchan's avatar
wxchan committed
400
  const Boosting* GetBoosting() const { return boosting_.get(); }
Guolin Ke's avatar
Guolin Ke committed
401

Nikita Titov's avatar
Nikita Titov committed
402
 private:
wxchan's avatar
wxchan committed
403
  const Dataset* train_data_;
Guolin Ke's avatar
Guolin Ke committed
404
  std::unique_ptr<Boosting> boosting_;
405
  std::unique_ptr<SingleRowPredictor> single_row_predictor_[PREDICTOR_TYPES];
406

Guolin Ke's avatar
Guolin Ke committed
407
  /*! \brief All configs */
Guolin Ke's avatar
Guolin Ke committed
408
  Config config_;
Guolin Ke's avatar
Guolin Ke committed
409
  /*! \brief Metric for training data */
Guolin Ke's avatar
Guolin Ke committed
410
  std::vector<std::unique_ptr<Metric>> train_metric_;
Guolin Ke's avatar
Guolin Ke committed
411
  /*! \brief Metrics for validation data */
Guolin Ke's avatar
Guolin Ke committed
412
  std::vector<std::vector<std::unique_ptr<Metric>>> valid_metrics_;
Guolin Ke's avatar
Guolin Ke committed
413
  /*! \brief Training objective function */
Guolin Ke's avatar
Guolin Ke committed
414
  std::unique_ptr<ObjectiveFunction> objective_fun_;
wxchan's avatar
wxchan committed
415
416
  /*! \brief mutex for threading safe call */
  std::mutex mutex_;
Guolin Ke's avatar
Guolin Ke committed
417
418
};

419
}  // namespace LightGBM
Guolin Ke's avatar
Guolin Ke committed
420
421
422

using namespace LightGBM;

Guolin Ke's avatar
Guolin Ke committed
423
424
425
426
427
428
429
430
// some helper functions used to convert data

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major);

431
432
433
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type);

Guolin Ke's avatar
Guolin Ke committed
434
435
std::function<std::vector<std::pair<int, double>>(int idx)>
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
436
                   const void* data, int data_type, int64_t nindptr, int64_t nelem);
Guolin Ke's avatar
Guolin Ke committed
437
438
439

// Row iterator over one column of a CSC matrix
class CSC_RowIterator {
 public:
  CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                  const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx);
  ~CSC_RowIterator() {}
  // Returns the value at idx; indices can only be accessed in ascending order
  double Get(int idx);
  // Returns the next non-zero (index, value) pair; an index < 0 means no more data
  std::pair<int, double> NextNonZero();

 private:
  int nonzero_idx_ = 0;
  int cur_idx_ = -1;
  double cur_val_ = 0.0;
  bool is_end_ = false;
  std::function<std::pair<int, double>(int idx)> iter_fun_;
};

// start of c_api functions

Guolin Ke's avatar
Guolin Ke committed
459
const char* LGBM_GetLastError() {
wxchan's avatar
wxchan committed
460
  return LastErrorMsg();
Guolin Ke's avatar
Guolin Ke committed
461
462
}

Guolin Ke's avatar
Guolin Ke committed
463
int LGBM_DatasetCreateFromFile(const char* filename,
464
465
466
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
467
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
468
469
  auto param = Config::Str2Map(parameters);
  Config config;
470
471
472
473
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
474
  DatasetLoader loader(config, nullptr, 1, filename);
Guolin Ke's avatar
Guolin Ke committed
475
  if (reference == nullptr) {
476
477
478
479
480
    if (Network::num_machines() == 1) {
      *out = loader.LoadFromFile(filename, "");
    } else {
      *out = loader.LoadFromFile(filename, "", Network::rank(), Network::num_machines());
    }
Guolin Ke's avatar
Guolin Ke committed
481
  } else {
482
    *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
483
                                                    reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
484
  }
485
  API_END();
Guolin Ke's avatar
Guolin Ke committed
486
487
}

488

Guolin Ke's avatar
Guolin Ke committed
489
int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
490
491
492
493
494
495
496
                                        int** sample_indices,
                                        int32_t ncol,
                                        const int* num_per_col,
                                        int32_t num_sample_row,
                                        int32_t num_total_row,
                                        const char* parameters,
                                        DatasetHandle* out) {
497
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
498
499
  auto param = Config::Str2Map(parameters);
  Config config;
500
501
502
503
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
504
  DatasetLoader loader(config, nullptr, 1, nullptr);
505
506
507
508
  *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
                                       num_sample_row,
                                       static_cast<data_size_t>(num_total_row));
  API_END();
Guolin Ke's avatar
Guolin Ke committed
509
510
}

511

Guolin Ke's avatar
Guolin Ke committed
512
int LGBM_DatasetCreateByReference(const DatasetHandle reference,
513
514
                                  int64_t num_total_row,
                                  DatasetHandle* out) {
Guolin Ke's avatar
Guolin Ke committed
515
516
517
518
519
520
521
522
  API_BEGIN();
  std::unique_ptr<Dataset> ret;
  ret.reset(new Dataset(static_cast<data_size_t>(num_total_row)));
  ret->CreateValid(reinterpret_cast<const Dataset*>(reference));
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
523
int LGBM_DatasetPushRows(DatasetHandle dataset,
524
525
526
527
528
                         const void* data,
                         int data_type,
                         int32_t nrow,
                         int32_t ncol,
                         int32_t start_row) {
Guolin Ke's avatar
Guolin Ke committed
529
530
531
  API_BEGIN();
  auto p_dataset = reinterpret_cast<Dataset*>(dataset);
  auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1);
532
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
533
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
534
  for (int i = 0; i < nrow; ++i) {
535
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
536
537
538
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    p_dataset->PushOneRow(tid, start_row + i, one_row);
539
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
540
  }
541
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
542
543
544
545
546
547
  if (start_row + nrow == p_dataset->num_data()) {
    p_dataset->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
548
// Pushes the rows of a CSR matrix into `dataset`, starting at row index
// start_row. The row count is nindptr - 1; the unnamed int64_t parameter is
// unused. Calls FinishLoad() once the pushed range ends exactly at the
// dataset's num_data().
// Returns 0 on success, or the value produced by the API_END() handlers.
int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset,
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t,
                              int64_t start_row) {
  API_BEGIN();
  auto target = reinterpret_cast<Dataset*>(dataset);
  auto read_row = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < nrow; ++i) {
    OMP_LOOP_EX_BEGIN();
    const int thread_id = omp_get_thread_num();
    auto row_pairs = read_row(i);
    const data_size_t dest_row = static_cast<data_size_t>(start_row + i);
    target->PushOneRow(thread_id, dest_row, row_pairs);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  const bool reached_last_row =
      (start_row + nrow == static_cast<int64_t>(target->num_data()));
  if (reached_last_row) {
    target->FinishLoad();
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
579
int LGBM_DatasetCreateFromMat(const void* data,
580
581
582
583
584
585
586
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
  return LGBM_DatasetCreateFromMats(1,
                                    &data,
                                    data_type,
                                    &nrow,
                                    ncol,
                                    is_row_major,
                                    parameters,
                                    reference,
                                    out);
}


int LGBM_DatasetCreateFromMats(int32_t nmat,
                               const void** data,
                               int data_type,
                               int32_t* nrow,
                               int32_t ncol,
                               int is_row_major,
                               const char* parameters,
                               const DatasetHandle reference,
                               DatasetHandle* out) {
608
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
609
610
  auto param = Config::Str2Map(parameters);
  Config config;
611
612
613
614
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
615
  std::unique_ptr<Dataset> ret;
616
617
618
619
620
621
622
623
624
  int32_t total_nrow = 0;
  for (int j = 0; j < nmat; ++j) {
    total_nrow += nrow[j];
  }

  std::vector<std::function<std::vector<double>(int row_idx)>> get_row_fun;
  for (int j = 0; j < nmat; ++j) {
    get_row_fun.push_back(RowFunctionFromDenseMatric(data[j], nrow[j], ncol, data_type, is_row_major));
  }
625

Guolin Ke's avatar
Guolin Ke committed
626
627
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
628
    Random rand(config.data_random_seed);
629
630
    int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(total_nrow, sample_cnt);
631
    sample_cnt = static_cast<int>(sample_indices.size());
632
    std::vector<std::vector<double>> sample_values(ncol);
Guolin Ke's avatar
Guolin Ke committed
633
    std::vector<std::vector<int>> sample_idx(ncol);
634
635
636

    int offset = 0;
    int j = 0;
Guolin Ke's avatar
Guolin Ke committed
637
    for (size_t i = 0; i < sample_indices.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
638
      auto idx = sample_indices[i];
639
640
641
642
      while ((idx - offset) >= nrow[j]) {
        offset += nrow[j];
        ++j;
      }
643

644
645
646
647
648
      auto row = get_row_fun[j](static_cast<int>(idx - offset));
      for (size_t k = 0; k < row.size(); ++k) {
        if (std::fabs(row[k]) > kZeroThreshold || std::isnan(row[k])) {
          sample_values[k].emplace_back(row[k]);
          sample_idx[k].emplace_back(static_cast<int>(i));
Guolin Ke's avatar
Guolin Ke committed
649
        }
Guolin Ke's avatar
Guolin Ke committed
650
651
      }
    }
Guolin Ke's avatar
Guolin Ke committed
652
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
653
654
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
655
                                            ncol,
656
                                            Common::VectorSize<double>(sample_values).data(),
657
                                            sample_cnt, total_nrow));
Guolin Ke's avatar
Guolin Ke committed
658
  } else {
659
    ret.reset(new Dataset(total_nrow));
Guolin Ke's avatar
Guolin Ke committed
660
    ret->CreateValid(
661
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
662
  }
663
664
665
666
667
668
669
670
671
672
673
674
675
676
  int32_t start_row = 0;
  for (int j = 0; j < nmat; ++j) {
    OMP_INIT_EX();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < nrow[j]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      auto one_row = get_row_fun[j](i);
      ret->PushOneRow(tid, start_row + i, one_row);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    start_row += nrow[j];
Guolin Ke's avatar
Guolin Ke committed
677
678
  }
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
679
  *out = ret.release();
680
  API_END();
681
682
}

Guolin Ke's avatar
Guolin Ke committed
683
int LGBM_DatasetCreateFromCSR(const void* indptr,
684
685
686
687
688
689
690
691
692
693
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
                              int64_t num_col,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
694
  API_BEGIN();
695
696
697
698
699
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
700
701
  auto param = Config::Str2Map(parameters);
  Config config;
702
703
704
705
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
706
  std::unique_ptr<Dataset> ret;
707
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
708
709
710
  int32_t nrow = static_cast<int32_t>(nindptr - 1);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
711
712
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
713
    auto sample_indices = rand.Sample(nrow, sample_cnt);
714
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
715
716
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
717
718
719
720
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      auto row = get_row_fun(static_cast<int>(idx));
      for (std::pair<int, double>& inner_data : row) {
Guolin Ke's avatar
Guolin Ke committed
721
        CHECK(inner_data.first < num_col);
Guolin Ke's avatar
Guolin Ke committed
722
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
723
724
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
725
726
727
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
728
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
729
730
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
731
                                            static_cast<int>(num_col),
732
733
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
734
  } else {
735
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
736
    ret->CreateValid(
737
      reinterpret_cast<const Dataset*>(reference));
738
  }
739
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
740
  #pragma omp parallel for schedule(static)
741
  for (int i = 0; i < nindptr - 1; ++i) {
742
    OMP_LOOP_EX_BEGIN();
743
744
745
    const int tid = omp_get_thread_num();
    auto one_row = get_row_fun(i);
    ret->PushOneRow(tid, i, one_row);
746
    OMP_LOOP_EX_END();
747
  }
748
  OMP_THROW_EX();
749
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
750
  *out = ret.release();
751
  API_END();
752
753
}

754
int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
755
756
757
758
759
                                  int num_rows,
                                  int64_t num_col,
                                  const char* parameters,
                                  const DatasetHandle reference,
                                  DatasetHandle* out) {
760
  API_BEGIN();
761
762
763
764
765
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
  auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
  auto param = Config::Str2Map(parameters);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  std::unique_ptr<Dataset> ret;
  int32_t nrow = num_rows;
  if (reference == nullptr) {
    // sample data first
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    sample_cnt = static_cast<int>(sample_indices.size());
    std::vector<std::vector<double>> sample_values(num_col);
    std::vector<std::vector<int>> sample_idx(num_col);
    // local buffer to re-use memory
    std::vector<std::pair<int, double>> buffer;
    for (size_t i = 0; i < sample_indices.size(); ++i) {
      auto idx = sample_indices[i];
      get_row_fun(static_cast<int>(idx), buffer);
      for (std::pair<int, double>& inner_data : buffer) {
        CHECK(inner_data.first < num_col);
        if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
          sample_values[inner_data.first].emplace_back(inner_data.second);
          sample_idx[inner_data.first].emplace_back(static_cast<int>(i));
        }
      }
    }
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
797
798
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
799
                                            static_cast<int>(num_col),
800
801
802
803
804
805
806
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
  } else {
    ret.reset(new Dataset(nrow));
    ret->CreateValid(
      reinterpret_cast<const Dataset*>(reference));
  }
807

808
809
810
811
812
813
  OMP_INIT_EX();
  std::vector<std::pair<int, double>> threadBuffer;
  #pragma omp parallel for schedule(static) private(threadBuffer)
  for (int i = 0; i < num_rows; ++i) {
    OMP_LOOP_EX_BEGIN();
    {
814
815
816
      const int tid = omp_get_thread_num();
      get_row_fun(i, threadBuffer);
      ret->PushOneRow(tid, i, threadBuffer);
817
818
819
820
821
822
823
824
825
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  ret->FinishLoad();
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
826
int LGBM_DatasetCreateFromCSC(const void* col_ptr,
827
828
829
830
831
832
833
834
835
836
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              const char* parameters,
                              const DatasetHandle reference,
                              DatasetHandle* out) {
837
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
838
839
  auto param = Config::Str2Map(parameters);
  Config config;
840
841
842
843
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
844
  std::unique_ptr<Dataset> ret;
Guolin Ke's avatar
Guolin Ke committed
845
846
847
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    // sample data first
Guolin Ke's avatar
Guolin Ke committed
848
849
    Random rand(config.data_random_seed);
    int sample_cnt = static_cast<int>(nrow < config.bin_construct_sample_cnt ? nrow : config.bin_construct_sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
850
    auto sample_indices = rand.Sample(nrow, sample_cnt);
851
    sample_cnt = static_cast<int>(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
852
    std::vector<std::vector<double>> sample_values(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
853
    std::vector<std::vector<int>> sample_idx(ncol_ptr - 1);
854
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
855
    #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
856
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
857
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
858
859
860
      CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
      for (int j = 0; j < sample_cnt; j++) {
        auto val = col_it.Get(sample_indices[j]);
Guolin Ke's avatar
Guolin Ke committed
861
        if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
862
863
          sample_values[i].emplace_back(val);
          sample_idx[i].emplace_back(j);
Guolin Ke's avatar
Guolin Ke committed
864
865
        }
      }
866
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
867
    }
868
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
869
    DatasetLoader loader(config, nullptr, 1, nullptr);
Guolin Ke's avatar
Guolin Ke committed
870
871
    ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
                                            Common::Vector2Ptr<int>(&sample_idx).data(),
872
873
874
                                            static_cast<int>(sample_values.size()),
                                            Common::VectorSize<double>(sample_values).data(),
                                            sample_cnt, nrow));
Guolin Ke's avatar
Guolin Ke committed
875
  } else {
876
    ret.reset(new Dataset(nrow));
Guolin Ke's avatar
Guolin Ke committed
877
    ret->CreateValid(
878
      reinterpret_cast<const Dataset*>(reference));
Guolin Ke's avatar
Guolin Ke committed
879
  }
880
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
881
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
882
  for (int i = 0; i < ncol_ptr - 1; ++i) {
883
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
884
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
885
    int feature_idx = ret->InnerFeatureIndex(i);
Guolin Ke's avatar
Guolin Ke committed
886
    if (feature_idx < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
887
888
    int group = ret->Feature2Group(feature_idx);
    int sub_feature = ret->Feture2SubFeature(feature_idx);
Guolin Ke's avatar
Guolin Ke committed
889
890
891
892
893
894
895
    CSC_RowIterator col_it(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, i);
    int row_idx = 0;
    while (row_idx < nrow) {
      auto pair = col_it.NextNonZero();
      row_idx = pair.first;
      // no more data
      if (row_idx < 0) { break; }
Guolin Ke's avatar
Guolin Ke committed
896
      ret->PushOneData(tid, row_idx, group, sub_feature, pair.second);
Guolin Ke's avatar
Guolin Ke committed
897
    }
898
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
899
  }
900
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
901
  ret->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
902
  *out = ret.release();
903
  API_END();
Guolin Ke's avatar
Guolin Ke committed
904
905
}

Guolin Ke's avatar
Guolin Ke committed
906
int LGBM_DatasetGetSubset(
907
  const DatasetHandle handle,
wxchan's avatar
wxchan committed
908
909
910
  const int32_t* used_row_indices,
  int32_t num_used_row_indices,
  const char* parameters,
Guolin Ke's avatar
typo  
Guolin Ke committed
911
  DatasetHandle* out) {
wxchan's avatar
wxchan committed
912
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
913
914
  auto param = Config::Str2Map(parameters);
  Config config;
915
916
917
918
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
919
  auto full_dataset = reinterpret_cast<const Dataset*>(handle);
Guolin Ke's avatar
Guolin Ke committed
920
  CHECK(num_used_row_indices > 0);
921
922
923
  const int32_t lower = 0;
  const int32_t upper = full_dataset->num_data() - 1;
  Common::CheckElementsIntervalClosed(used_row_indices, lower, upper, num_used_row_indices, "Used indices of subset");
924
925
926
  if (!std::is_sorted(used_row_indices, used_row_indices + num_used_row_indices)) {
    Log::Fatal("used_row_indices should be sorted in Subset");
  }
Guolin Ke's avatar
Guolin Ke committed
927
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_row_indices));
928
  ret->CopyFeatureMapperFrom(full_dataset);
Guolin Ke's avatar
Guolin Ke committed
929
  ret->CopySubset(full_dataset, used_row_indices, num_used_row_indices, true);
wxchan's avatar
wxchan committed
930
931
932
933
  *out = ret.release();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
934
int LGBM_DatasetSetFeatureNames(
Guolin Ke's avatar
typo  
Guolin Ke committed
935
  DatasetHandle handle,
Guolin Ke's avatar
Guolin Ke committed
936
  const char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
937
  int num_feature_names) {
Guolin Ke's avatar
Guolin Ke committed
938
939
940
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  std::vector<std::string> feature_names_str;
Guolin Ke's avatar
Guolin Ke committed
941
  for (int i = 0; i < num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
942
943
944
945
946
947
    feature_names_str.emplace_back(feature_names[i]);
  }
  dataset->set_feature_names(feature_names_str);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
948
int LGBM_DatasetGetFeatureNames(
949
950
  DatasetHandle handle,
  char** feature_names,
Guolin Ke's avatar
Guolin Ke committed
951
  int* num_feature_names) {
952
953
954
  API_BEGIN();
  auto dataset = reinterpret_cast<Dataset*>(handle);
  auto inside_feature_name = dataset->feature_names();
Guolin Ke's avatar
Guolin Ke committed
955
956
  *num_feature_names = static_cast<int>(inside_feature_name.size());
  for (int i = 0; i < *num_feature_names; ++i) {
Guolin Ke's avatar
Guolin Ke committed
957
    std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
958
959
960
961
  }
  API_END();
}

962
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
963
// Deletes the Dataset object that `handle` points to.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_DatasetFree(DatasetHandle handle) {
  API_BEGIN();
  Dataset* doomed = reinterpret_cast<Dataset*>(handle);
  delete doomed;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
969
int LGBM_DatasetSaveBinary(DatasetHandle handle,
970
                           const char* filename) {
971
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
972
973
  auto dataset = reinterpret_cast<Dataset*>(handle);
  dataset->SaveBinaryFile(filename);
974
  API_END();
975
976
}

977
978
979
980
981
982
983
984
// Calls Dataset::DumpTextFile(filename) on the Dataset behind `handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_DatasetDumpText(DatasetHandle handle,
                         const char* filename) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->DumpTextFile(filename);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
985
// Stores `num_element` values from `field_data` under `field_name`.
//
// `type` selects how `field_data` is interpreted and which setter is used:
//   C_API_DTYPE_FLOAT32 -> SetFloatField  (const float*)
//   C_API_DTYPE_INT32   -> SetIntField    (const int*)
//   C_API_DTYPE_FLOAT64 -> SetDoubleField (const double*)
// Any other `type`, or a setter that returns false, raises
// std::runtime_error, which API_END turns into a -1 return. Returns 0 on
// success.
int LGBM_DatasetSetField(DatasetHandle handle,
                         const char* field_name,
                         const void* field_data,
                         int num_element,
                         int type) {
  API_BEGIN();
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  const int32_t count = static_cast<int32_t>(num_element);
  bool stored = false;
  if (type == C_API_DTYPE_FLOAT32) {
    const float* values = reinterpret_cast<const float*>(field_data);
    stored = target->SetFloatField(field_name, values, count);
  } else if (type == C_API_DTYPE_INT32) {
    const int* values = reinterpret_cast<const int*>(field_data);
    stored = target->SetIntField(field_name, values, count);
  } else if (type == C_API_DTYPE_FLOAT64) {
    const double* values = reinterpret_cast<const double*>(field_data);
    stored = target->SetDoubleField(field_name, values, count);
  }
  if (!stored) {
    throw std::runtime_error("Input data type error or field not found");
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1004
// Looks up `field_name` by trying the typed getters in this order: float,
// int, double, int8. The first getter that returns true determines
// *out_type (C_API_DTYPE_FLOAT32 / INT32 / FLOAT64 / INT8); the getters
// receive out_len and out_ptr to fill in.
//
// If no getter succeeds, std::runtime_error("Field not found") is thrown and
// API_END turns it into a -1 return. If the resulting *out_ptr is null,
// *out_len is forced to 0. Returns 0 on success.
int LGBM_DatasetGetField(DatasetHandle handle,
                         const char* field_name,
                         int* out_len,
                         const void** out_ptr,
                         int* out_type) {
  API_BEGIN();
  Dataset* source = reinterpret_cast<Dataset*>(handle);
  bool found = true;
  if (source->GetFloatField(field_name, out_len, reinterpret_cast<const float**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT32;
  } else if (source->GetIntField(field_name, out_len, reinterpret_cast<const int**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT32;
  } else if (source->GetDoubleField(field_name, out_len, reinterpret_cast<const double**>(out_ptr))) {
    *out_type = C_API_DTYPE_FLOAT64;
  } else if (source->GetInt8Field(field_name, out_len, reinterpret_cast<const int8_t**>(out_ptr))) {
    *out_type = C_API_DTYPE_INT8;
  } else {
    found = false;
  }
  if (!found) {
    throw std::runtime_error("Field not found");
  }
  if (*out_ptr == nullptr) {
    *out_len = 0;
  }
  API_END();
}

1030
1031
1032
1033
1034
1035
1036
// Calls Dataset::ResetConfig(parameters) on the Dataset behind `handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_DatasetUpdateParam(DatasetHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Dataset*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1037
int LGBM_DatasetGetNumData(DatasetHandle handle,
1038
                           int* out) {
1039
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1040
1041
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_data();
1042
  API_END();
1043
1044
}

Guolin Ke's avatar
Guolin Ke committed
1045
int LGBM_DatasetGetNumFeature(DatasetHandle handle,
1046
                              int* out) {
1047
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1048
1049
  auto dataset = reinterpret_cast<Dataset*>(handle);
  *out = dataset->num_total_features();
1050
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1051
}
1052

1053
1054
1055
1056
1057
1058
1059
1060
1061
// Calls Dataset::addFeaturesFrom on the `target` Dataset, passing the
// `source` Dataset.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_DatasetAddFeaturesFrom(DatasetHandle target,
                                DatasetHandle source) {
  API_BEGIN();
  Dataset* donor = reinterpret_cast<Dataset*>(source);
  Dataset* receiver = reinterpret_cast<Dataset*>(target);
  receiver->addFeaturesFrom(donor);
  API_END();
}

1062
1063
// ---- start of booster

Guolin Ke's avatar
Guolin Ke committed
1064
int LGBM_BoosterCreate(const DatasetHandle train_data,
1065
1066
                       const char* parameters,
                       BoosterHandle* out) {
1067
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1068
  const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
wxchan's avatar
wxchan committed
1069
1070
  auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
  *out = ret.release();
1071
  API_END();
1072
1073
}

Guolin Ke's avatar
Guolin Ke committed
1074
int LGBM_BoosterCreateFromModelfile(
1075
  const char* filename,
Guolin Ke's avatar
Guolin Ke committed
1076
  int* out_num_iterations,
1077
  BoosterHandle* out) {
1078
  API_BEGIN();
wxchan's avatar
wxchan committed
1079
  auto ret = std::unique_ptr<Booster>(new Booster(filename));
Guolin Ke's avatar
Guolin Ke committed
1080
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
wxchan's avatar
wxchan committed
1081
  *out = ret.release();
1082
  API_END();
1083
1084
}

Guolin Ke's avatar
Guolin Ke committed
1085
int LGBM_BoosterLoadModelFromString(
1086
1087
1088
1089
  const char* model_str,
  int* out_num_iterations,
  BoosterHandle* out) {
  API_BEGIN();
wxchan's avatar
wxchan committed
1090
  auto ret = std::unique_ptr<Booster>(new Booster(nullptr));
1091
1092
1093
1094
1095
1096
  ret->LoadModelFromString(model_str);
  *out_num_iterations = ret->GetBoosting()->GetCurrentIteration();
  *out = ret.release();
  API_END();
}

1097
#pragma warning(disable : 4702)
Guolin Ke's avatar
Guolin Ke committed
1098
// Deletes the Booster object that `handle` points to.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterFree(BoosterHandle handle) {
  API_BEGIN();
  Booster* doomed = reinterpret_cast<Booster*>(handle);
  delete doomed;
  API_END();
}

1104
// Calls Booster::ShuffleModels(start_iter, end_iter) on the Booster behind
// `handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterShuffleModels(BoosterHandle handle, int start_iter, int end_iter) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ShuffleModels(start_iter, end_iter);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1111
// Calls Booster::MergeFrom on the Booster behind `handle`, passing the
// Booster behind `other_handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterMerge(BoosterHandle handle,
                      BoosterHandle other_handle) {
  API_BEGIN();
  Booster* incoming = reinterpret_cast<Booster*>(other_handle);
  Booster* receiver = reinterpret_cast<Booster*>(handle);
  receiver->MergeFrom(incoming);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1120
int LGBM_BoosterAddValidData(BoosterHandle handle,
1121
                             const DatasetHandle valid_data) {
wxchan's avatar
wxchan committed
1122
1123
1124
1125
1126
1127
1128
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1129
int LGBM_BoosterResetTrainingData(BoosterHandle handle,
1130
                                  const DatasetHandle train_data) {
wxchan's avatar
wxchan committed
1131
1132
1133
1134
1135
1136
1137
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1138
// Calls Booster::ResetConfig(parameters) on the Booster behind `handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->ResetConfig(parameters);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1145
// Writes NumberOfClasses() of the booster's underlying boosting object to
// *out_len.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterGetNumClasses(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->NumberOfClasses();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1152
1153
1154
1155
1156
1157
1158
// Calls Booster::Refit(leaf_preds, nrow, ncol) on the Booster behind `handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t nrow, int32_t ncol) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->Refit(leaf_preds, nrow, ncol);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1159
// Runs Booster::TrainOneIter() on the Booster behind `handle` and writes 1
// to *is_finished if it returned true, 0 otherwise.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const bool finished = booster->TrainOneIter();
  *is_finished = finished ? 1 : 0;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1170
// Runs Booster::TrainOneIter(grad, hess) on the Booster behind `handle` and
// writes 1 to *is_finished if it returned true, 0 otherwise.
//
// When the build defines SCORE_T_USE_DOUBLE this path is compiled out and
// Log::Fatal is called instead.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle,
                                    const float* grad,
                                    const float* hess,
                                    int* is_finished) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  #ifdef SCORE_T_USE_DOUBLE
  Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled");
  #else
  const bool finished = booster->TrainOneIter(grad, hess);
  *is_finished = finished ? 1 : 0;
  #endif
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1188
// Calls Booster::RollbackOneIter() on the Booster behind `handle`.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterRollbackOneIter(BoosterHandle handle) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->RollbackOneIter();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1195
// Writes GetCurrentIteration() of the booster's underlying boosting object
// to *out_iteration.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_iteration = boosting->GetCurrentIteration();
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1201

1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
// Writes NumModelPerIteration() of the booster's underlying boosting object
// to *out_tree_per_iteration.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_tree_per_iteration = boosting->NumModelPerIteration();
  API_END();
}

// Writes NumberOfTotalModel() of the booster's underlying boosting object to
// *out_models.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_models = boosting->NumberOfTotalModel();
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1216
// Writes Booster::GetEvalCounts() of the Booster behind `handle` to *out_len.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto eval_count = booster->GetEvalCounts();
  *out_len = eval_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1223
// Passes `out_strs` to Booster::GetEvalNames and writes its return value to
// *out_len.
// NOTE(review): presumably GetEvalNames copies the names into the
// caller-provided buffers in out_strs - confirm in Booster.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterGetEvalNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetEvalNames(out_strs);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1230
// Passes `out_strs` to Booster::GetFeatureNames and writes its return value
// to *out_len.
// NOTE(review): presumably GetFeatureNames copies the names into the
// caller-provided buffers in out_strs - confirm in Booster.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterGetFeatureNames(BoosterHandle handle, int* out_len, char** out_strs) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  const auto name_count = booster->GetFeatureNames(out_strs);
  *out_len = name_count;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1237
// Writes MaxFeatureIdx() + 1 of the booster's underlying boosting object to
// *out_len.
// Returns 0, or -1 if an exception is caught by API_END.
int LGBM_BoosterGetNumFeature(BoosterHandle handle, int* out_len) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto boosting = booster->GetBoosting();
  *out_len = boosting->MaxFeatureIdx() + 1;
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1244
/*!
 * \brief Copy the evaluation results returned by GetEvalAt(data_idx) into a caller buffer.
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param data_idx Index passed through to GetEvalAt()
 * \param out_len Receives the number of results
 * \param out_results Caller buffer; one double is written per result, no size check is done here
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterGetEval(BoosterHandle handle,
                        int data_idx,
                        int* out_len,
                        double* out_results) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  auto eval_values = boosting->GetEvalAt(data_idx);
  const size_t num_values = eval_values.size();
  *out_len = static_cast<int>(num_values);
  for (size_t pos = 0; pos < num_values; ++pos) {
    out_results[pos] = static_cast<double>(eval_values[pos]);
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1259
int LGBM_BoosterGetNumPredict(BoosterHandle handle,
1260
1261
                              int data_idx,
                              int64_t* out_len) {
Guolin Ke's avatar
Guolin Ke committed
1262
1263
1264
1265
1266
1267
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  *out_len = boosting->GetNumPredictAt(data_idx);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1268
int LGBM_BoosterGetPredict(BoosterHandle handle,
1269
1270
1271
                           int data_idx,
                           int64_t* out_len,
                           double* out_result) {
1272
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1273
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1274
  ref_booster->GetPredictAt(data_idx, out_result, out_len);
1275
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1276
1277
}

Guolin Ke's avatar
Guolin Ke committed
1278
int LGBM_BoosterPredictForFile(BoosterHandle handle,
1279
1280
1281
1282
                               const char* data_filename,
                               int data_has_header,
                               int predict_type,
                               int num_iteration,
1283
                               const char* parameter,
1284
                               const char* result_filename) {
1285
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1286
1287
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1288
1289
1290
1291
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1292
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
cbecker's avatar
cbecker committed
1293
  ref_booster->Predict(num_iteration, predict_type, data_filename, data_has_header,
Guolin Ke's avatar
Guolin Ke committed
1294
                       config, result_filename);
1295
  API_END();
1296
1297
}

Guolin Ke's avatar
Guolin Ke committed
1298
/*!
 * \brief Compute the output length for predicting num_row rows.
 *
 * The result is num_row (widened to int64_t) multiplied by
 * NumPredictOneRow(num_iteration, is_leaf_index, is_contrib).
 *
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param num_row Number of rows
 * \param predict_type Compared against C_API_PREDICT_LEAF_INDEX and C_API_PREDICT_CONTRIB
 * \param num_iteration Forwarded to NumPredictOneRow()
 * \param out_len Receives the computed length
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
                               int num_row,
                               int predict_type,
                               int num_iteration,
                               int64_t* out_len) {
  API_BEGIN();
  auto boosting = reinterpret_cast<Booster*>(handle)->GetBoosting();
  const bool is_leaf_index = (predict_type == C_API_PREDICT_LEAF_INDEX);
  const bool is_contrib = (predict_type == C_API_PREDICT_CONTRIB);
  *out_len = static_cast<int64_t>(num_row) *
             boosting->NumPredictOneRow(num_iteration, is_leaf_index, is_contrib);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1310
int LGBM_BoosterPredictForCSR(BoosterHandle handle,
1311
1312
1313
1314
1315
1316
1317
                              const void* indptr,
                              int indptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t nindptr,
                              int64_t nelem,
1318
                              int64_t num_col,
1319
1320
                              int predict_type,
                              int num_iteration,
1321
                              const char* parameter,
1322
1323
                              int64_t* out_len,
                              double* out_result) {
1324
  API_BEGIN();
1325
1326
1327
1328
1329
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
Guolin Ke's avatar
Guolin Ke committed
1330
1331
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1332
1333
1334
1335
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1336
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1337
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
Guolin Ke's avatar
Guolin Ke committed
1338
  int nrow = static_cast<int>(nindptr - 1);
1339
  ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1340
                       config, out_result, out_len);
1341
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1342
}
1343

1344
int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
1345
1346
1347
1348
1349
1350
1351
                                       const void* indptr,
                                       int indptr_type,
                                       const int32_t* indices,
                                       const void* data,
                                       int data_type,
                                       int64_t nindptr,
                                       int64_t nelem,
1352
                                       int64_t num_col,
1353
1354
1355
1356
1357
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
1358
  API_BEGIN();
1359
1360
1361
1362
1363
  if (num_col <= 0) {
    Log::Fatal("The number of columns should be greater than zero.");
  } else if (num_col >= INT32_MAX) {
    Log::Fatal("The number of columns should be smaller than INT32_MAX.");
  }
1364
1365
1366
1367
1368
1369
1370
1371
  auto param = Config::Str2Map(parameter);
  Config config;
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
1372
  ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
1373
1374
1375
1376
  API_END();
}


Guolin Ke's avatar
Guolin Ke committed
1377
int LGBM_BoosterPredictForCSC(BoosterHandle handle,
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
                              const void* col_ptr,
                              int col_ptr_type,
                              const int32_t* indices,
                              const void* data,
                              int data_type,
                              int64_t ncol_ptr,
                              int64_t nelem,
                              int64_t num_row,
                              int predict_type,
                              int num_iteration,
1388
                              const char* parameter,
1389
1390
                              int64_t* out_len,
                              double* out_result) {
Guolin Ke's avatar
Guolin Ke committed
1391
1392
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Guolin Ke's avatar
Guolin Ke committed
1393
1394
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
1405
  int ncol = static_cast<int>(ncol_ptr - 1);
Guolin Ke's avatar
Guolin Ke committed
1406
1407
1408
1409
1410
  std::vector<std::vector<CSC_RowIterator>> iterators(num_threads, std::vector<CSC_RowIterator>());
  for (int i = 0; i < num_threads; ++i) {
    for (int j = 0; j < ncol; ++j) {
      iterators[i].emplace_back(col_ptr, col_ptr_type, indices, data, data_type, ncol_ptr, nelem, j);
    }
Guolin Ke's avatar
Guolin Ke committed
1411
1412
  }
  std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun =
1413
    [&iterators, ncol] (int i) {
Guolin Ke's avatar
Guolin Ke committed
1414
    std::vector<std::pair<int, double>> one_row;
Guolin Ke's avatar
Guolin Ke committed
1415
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
1416
    for (int j = 0; j < ncol; ++j) {
Guolin Ke's avatar
Guolin Ke committed
1417
      auto val = iterators[tid][j].Get(i);
Guolin Ke's avatar
Guolin Ke committed
1418
      if (std::fabs(val) > kZeroThreshold || std::isnan(val)) {
Guolin Ke's avatar
Guolin Ke committed
1419
        one_row.emplace_back(j, val);
Guolin Ke's avatar
Guolin Ke committed
1420
1421
      }
    }
Guolin Ke's avatar
Guolin Ke committed
1422
1423
    return one_row;
  };
1424
  ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
cbecker's avatar
cbecker committed
1425
                       out_result, out_len);
Guolin Ke's avatar
Guolin Ke committed
1426
1427
1428
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1429
int LGBM_BoosterPredictForMat(BoosterHandle handle,
1430
1431
1432
1433
1434
1435
1436
                              const void* data,
                              int data_type,
                              int32_t nrow,
                              int32_t ncol,
                              int is_row_major,
                              int predict_type,
                              int num_iteration,
1437
                              const char* parameter,
1438
1439
                              int64_t* out_len,
                              double* out_result) {
1440
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1441
1442
  auto param = Config::Str2Map(parameter);
  Config config;
Guolin Ke's avatar
Guolin Ke committed
1443
1444
1445
1446
  config.Set(param);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
Guolin Ke's avatar
Guolin Ke committed
1447
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1448
  auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
1449
  ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
Guolin Ke's avatar
Guolin Ke committed
1450
                       config, out_result, out_len);
1451
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1452
}
1453

1454
/*!
 * \brief Predict for a single dense row via Booster::PredictSingleRow().
 *
 * The buffer is treated as a matrix with exactly one row.
 *
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param data Row buffer, element type selected by data_type
 * \param data_type Type tag for data
 * \param ncol Number of columns
 * \param is_row_major Forwarded to RowPairFunctionFromDenseMatric()
 * \param predict_type Forwarded to Booster::PredictSingleRow()
 * \param num_iteration Forwarded to Booster::PredictSingleRow()
 * \param parameter Parameter string parsed by Config::Str2Map()
 * \param out_len Forwarded as length output
 * \param out_result Forwarded as result buffer
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
                                       const void* data,
                                       int data_type,
                                       int32_t ncol,
                                       int is_row_major,
                                       int predict_type,
                                       int num_iteration,
                                       const char* parameter,
                                       int64_t* out_len,
                                       double* out_result) {
  API_BEGIN();
  Config config;
  auto parsed_params = Config::Str2Map(parameter);
  config.Set(parsed_params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  const int single_row = 1;
  auto row_reader = RowPairFunctionFromDenseMatric(data, single_row, ncol, data_type, is_row_major);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->PredictSingleRow(num_iteration, predict_type, ncol, row_reader, config, out_result, out_len);
  API_END();
}


1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
/*!
 * \brief Predict for dense data supplied as an array of row pointers.
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param data Array of pointers, one per row; element type selected by data_type
 * \param data_type Type tag for the row buffers
 * \param nrow Number of rows
 * \param ncol Number of columns
 * \param predict_type Forwarded to Booster::Predict()
 * \param num_iteration Forwarded to Booster::Predict()
 * \param parameter Parameter string parsed by Config::Str2Map()
 * \param out_len Forwarded as length output
 * \param out_result Forwarded as result buffer
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterPredictForMats(BoosterHandle handle,
                               const void** data,
                               int data_type,
                               int32_t nrow,
                               int32_t ncol,
                               int predict_type,
                               int num_iteration,
                               const char* parameter,
                               int64_t* out_len,
                               double* out_result) {
  API_BEGIN();
  Config config;
  auto parsed_params = Config::Str2Map(parameter);
  config.Set(parsed_params);
  if (config.num_threads > 0) {
    omp_set_num_threads(config.num_threads);
  }
  auto row_reader = RowPairFunctionFromDenseRows(data, ncol, data_type);
  Booster* booster = reinterpret_cast<Booster*>(handle);
  booster->Predict(num_iteration, predict_type, nrow, ncol, row_reader, config, out_result, out_len);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1501
int LGBM_BoosterSaveModel(BoosterHandle handle,
1502
                          int start_iteration,
1503
1504
                          int num_iteration,
                          const char* filename) {
1505
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1506
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1507
  ref_booster->SaveModelToFile(start_iteration, num_iteration, filename);
wxchan's avatar
wxchan committed
1508
1509
1510
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1511
int LGBM_BoosterSaveModelToString(BoosterHandle handle,
1512
                                  int start_iteration,
1513
                                  int num_iteration,
1514
                                  int64_t buffer_len,
1515
                                  int64_t* out_len,
1516
                                  char* out_str) {
1517
1518
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1519
  std::string model = ref_booster->SaveModelToString(start_iteration, num_iteration);
1520
  *out_len = static_cast<int64_t>(model.size()) + 1;
1521
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1522
    std::memcpy(out_str, model.c_str(), *out_len);
1523
1524
1525
1526
  }
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1527
int LGBM_BoosterDumpModel(BoosterHandle handle,
1528
                          int start_iteration,
1529
                          int num_iteration,
1530
1531
                          int64_t buffer_len,
                          int64_t* out_len,
1532
                          char* out_str) {
wxchan's avatar
wxchan committed
1533
1534
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
1535
  std::string model = ref_booster->DumpModel(start_iteration, num_iteration);
1536
  *out_len = static_cast<int64_t>(model.size()) + 1;
wxchan's avatar
wxchan committed
1537
  if (*out_len <= buffer_len) {
Guolin Ke's avatar
Guolin Ke committed
1538
    std::memcpy(out_str, model.c_str(), *out_len);
wxchan's avatar
wxchan committed
1539
  }
1540
  API_END();
Guolin Ke's avatar
Guolin Ke committed
1541
}
1542

Guolin Ke's avatar
Guolin Ke committed
1543
/*!
 * \brief Read one leaf value through Booster::GetLeafValue().
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param tree_idx Forwarded to GetLeafValue()
 * \param leaf_idx Forwarded to GetLeafValue()
 * \param out_val Receives the value, converted to double
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterGetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double* out_val) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  auto leaf_value = booster->GetLeafValue(tree_idx, leaf_idx);
  *out_val = static_cast<double>(leaf_value);
  API_END();
}

Guolin Ke's avatar
Guolin Ke committed
1553
/*!
 * \brief Overwrite one leaf value through Booster::SetLeafValue().
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param tree_idx Forwarded to SetLeafValue()
 * \param leaf_idx Forwarded to SetLeafValue()
 * \param val New value, forwarded to SetLeafValue()
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterSetLeafValue(BoosterHandle handle,
                             int tree_idx,
                             int leaf_idx,
                             double val) {
  API_BEGIN();
  reinterpret_cast<Booster*>(handle)->SetLeafValue(tree_idx, leaf_idx, val);
  API_END();
}

1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
/*!
 * \brief Copy the vector returned by Booster::FeatureImportance() into a caller buffer.
 * \param handle Opaque handle; reinterpreted as a Booster pointer
 * \param num_iteration Forwarded to FeatureImportance()
 * \param importance_type Forwarded to FeatureImportance()
 * \param out_results Caller buffer; one double is written per vector entry, no size check is done here
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_BoosterFeatureImportance(BoosterHandle handle,
                                  int num_iteration,
                                  int importance_type,
                                  double* out_results) {
  API_BEGIN();
  Booster* booster = reinterpret_cast<Booster*>(handle);
  std::vector<double> importances = booster->FeatureImportance(num_iteration, importance_type);
  double* dst = out_results;
  for (const double value : importances) {
    *dst = value;
    ++dst;
  }
  API_END();
}

1576
1577
1578
1579
1580
int LGBM_NetworkInit(const char* machines,
                     int local_listen_port,
                     int listen_time_out,
                     int num_machines) {
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1581
  Config config;
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
  config.machines = Common::RemoveQuotationSymbol(std::string(machines));
  config.local_listen_port = local_listen_port;
  config.num_machines = num_machines;
  config.time_out = listen_time_out;
  if (num_machines > 1) {
    Network::Init(config);
  }
  API_END();
}

/*!
 * \brief Release network state by calling Network::Dispose().
 * \return 0 on success, -1 when an exception is caught by API_END
 */
int LGBM_NetworkFree() {
  API_BEGIN();
  { Network::Dispose(); }
  API_END();
}

1598
1599
1600
int LGBM_NetworkInitWithFunctions(int num_machines, int rank,
                                  void* reduce_scatter_ext_fun,
                                  void* allgather_ext_fun) {
ww's avatar
ww committed
1601
  API_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1602
  if (num_machines > 1) {
1603
    Network::Init(num_machines, rank, (ReduceScatterFunction)reduce_scatter_ext_fun, (AllgatherFunction)allgather_ext_fun);
ww's avatar
ww committed
1604
1605
1606
  }
  API_END();
}
Guolin Ke's avatar
Guolin Ke committed
1607

Guolin Ke's avatar
Guolin Ke committed
1608
// ---- start of some helper functions
1609
1610
1611

std::function<std::vector<double>(int row_idx)>
RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1612
  if (data_type == C_API_DTYPE_FLOAT32) {
1613
1614
    const float* data_ptr = reinterpret_cast<const float*>(data);
    if (is_row_major) {
1615
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1616
        std::vector<double> ret(num_col);
1617
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1618
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1619
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1620
1621
1622
1623
        }
        return ret;
      };
    } else {
1624
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1625
        std::vector<double> ret(num_col);
1626
        for (int i = 0; i < num_col; ++i) {
1627
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1628
1629
1630
1631
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1632
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1633
1634
    const double* data_ptr = reinterpret_cast<const double*>(data);
    if (is_row_major) {
1635
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1636
        std::vector<double> ret(num_col);
1637
        auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
1638
        for (int i = 0; i < num_col; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1639
          ret[i] = static_cast<double>(*(tmp_ptr + i));
1640
1641
1642
1643
        }
        return ret;
      };
    } else {
1644
      return [=] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1645
        std::vector<double> ret(num_col);
1646
        for (int i = 0; i < num_col; ++i) {
1647
          ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
1648
1649
1650
1651
1652
        }
        return ret;
      };
    }
  }
1653
  throw std::runtime_error("Unknown data type in RowFunctionFromDenseMatric");
1654
1655
1656
1657
}

std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_type, int is_row_major) {
Guolin Ke's avatar
Guolin Ke committed
1658
1659
  auto inner_function = RowFunctionFromDenseMatric(data, num_row, num_col, data_type, is_row_major);
  if (inner_function != nullptr) {
1660
    return [inner_function] (int row_idx) {
Guolin Ke's avatar
Guolin Ke committed
1661
1662
1663
      auto raw_values = inner_function(row_idx);
      std::vector<std::pair<int, double>> ret;
      for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1664
        if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
Guolin Ke's avatar
Guolin Ke committed
1665
          ret.emplace_back(i, raw_values[i]);
1666
        }
Guolin Ke's avatar
Guolin Ke committed
1667
1668
1669
      }
      return ret;
    };
1670
  }
Guolin Ke's avatar
Guolin Ke committed
1671
  return nullptr;
1672
1673
}

1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
// data is array of pointers to individual rows
std::function<std::vector<std::pair<int, double>>(int row_idx)>
RowPairFunctionFromDenseRows(const void** data, int num_col, int data_type) {
  return [=](int row_idx) {
    auto inner_function = RowFunctionFromDenseMatric(data[row_idx], 1, num_col, data_type, /* is_row_major */ true);
    auto raw_values = inner_function(0);
    std::vector<std::pair<int, double>> ret;
    for (int i = 0; i < static_cast<int>(raw_values.size()); ++i) {
      if (std::fabs(raw_values[i]) > kZeroThreshold || std::isnan(raw_values[i])) {
        ret.emplace_back(i, raw_values[i]);
      }
    }
    return ret;
  };
}

1690
std::function<std::vector<std::pair<int, double>>(int idx)>
1691
RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, const void* data, int data_type, int64_t , int64_t ) {
Guolin Ke's avatar
Guolin Ke committed
1692
  if (data_type == C_API_DTYPE_FLOAT32) {
1693
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1694
    if (indptr_type == C_API_DTYPE_INT32) {
1695
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1696
      return [=] (int idx) {
1697
1698
1699
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1700
1701
1702
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1703
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1704
          ret.emplace_back(indices[i], data_ptr[i]);
1705
1706
1707
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1708
    } else if (indptr_type == C_API_DTYPE_INT64) {
1709
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1710
      return [=] (int idx) {
1711
1712
1713
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1714
1715
1716
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1717
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1718
          ret.emplace_back(indices[i], data_ptr[i]);
1719
1720
1721
1722
        }
        return ret;
      };
    }
Guolin Ke's avatar
Guolin Ke committed
1723
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1724
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1725
    if (indptr_type == C_API_DTYPE_INT32) {
1726
      const int32_t* ptr_indptr = reinterpret_cast<const int32_t*>(indptr);
1727
      return [=] (int idx) {
1728
1729
1730
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1731
1732
1733
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1734
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1735
          ret.emplace_back(indices[i], data_ptr[i]);
1736
1737
1738
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1739
    } else if (indptr_type == C_API_DTYPE_INT64) {
1740
      const int64_t* ptr_indptr = reinterpret_cast<const int64_t*>(indptr);
1741
      return [=] (int idx) {
1742
1743
1744
        std::vector<std::pair<int, double>> ret;
        int64_t start = ptr_indptr[idx];
        int64_t end = ptr_indptr[idx + 1];
1745
1746
1747
        if (end - start > 0)  {
          ret.reserve(end - start);
        }
Guolin Ke's avatar
Guolin Ke committed
1748
        for (int64_t i = start; i < end; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1749
          ret.emplace_back(indices[i], data_ptr[i]);
1750
1751
1752
        }
        return ret;
      };
Guolin Ke's avatar
Guolin Ke committed
1753
1754
    }
  }
1755
  throw std::runtime_error("Unknown data type in RowFunctionFromCSR");
1756
1757
}

Guolin Ke's avatar
Guolin Ke committed
1758
std::function<std::pair<int, double>(int idx)>
1759
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
Guolin Ke's avatar
Guolin Ke committed
1760
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
Guolin Ke's avatar
Guolin Ke committed
1761
  if (data_type == C_API_DTYPE_FLOAT32) {
1762
    const float* data_ptr = reinterpret_cast<const float*>(data);
Guolin Ke's avatar
Guolin Ke committed
1763
    if (col_ptr_type == C_API_DTYPE_INT32) {
1764
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1765
1766
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1767
1768
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1769
1770
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1771
        }
Guolin Ke's avatar
Guolin Ke committed
1772
1773
1774
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1775
      };
Guolin Ke's avatar
Guolin Ke committed
1776
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1777
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1778
1779
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1780
1781
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1782
1783
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1784
        }
Guolin Ke's avatar
Guolin Ke committed
1785
1786
1787
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1788
      };
Guolin Ke's avatar
Guolin Ke committed
1789
    }
Guolin Ke's avatar
Guolin Ke committed
1790
  } else if (data_type == C_API_DTYPE_FLOAT64) {
1791
    const double* data_ptr = reinterpret_cast<const double*>(data);
Guolin Ke's avatar
Guolin Ke committed
1792
    if (col_ptr_type == C_API_DTYPE_INT32) {
1793
      const int32_t* ptr_col_ptr = reinterpret_cast<const int32_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1794
1795
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1796
1797
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1798
1799
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1800
        }
Guolin Ke's avatar
Guolin Ke committed
1801
1802
1803
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1804
      };
Guolin Ke's avatar
Guolin Ke committed
1805
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
1806
      const int64_t* ptr_col_ptr = reinterpret_cast<const int64_t*>(col_ptr);
Guolin Ke's avatar
Guolin Ke committed
1807
1808
      int64_t start = ptr_col_ptr[col_idx];
      int64_t end = ptr_col_ptr[col_idx + 1];
1809
1810
      return [=] (int offset) {
        int64_t i = static_cast<int64_t>(start + offset);
Guolin Ke's avatar
Guolin Ke committed
1811
1812
        if (i >= end) {
          return std::make_pair(-1, 0.0);
1813
        }
Guolin Ke's avatar
Guolin Ke committed
1814
1815
1816
        int idx = static_cast<int>(indices[i]);
        double val = static_cast<double>(data_ptr[i]);
        return std::make_pair(idx, val);
1817
      };
Guolin Ke's avatar
Guolin Ke committed
1818
1819
    }
  }
1820
  throw std::runtime_error("Unknown data type in CSC matrix");
1821
1822
}

Guolin Ke's avatar
Guolin Ke committed
1823
/*!
 * \brief Construct an iterator over one column of a CSC matrix.
 *
 * All arguments are handed to IterateFunctionFromCSC(); the resulting
 * callback is stored in iter_fun_.
 */
CSC_RowIterator::CSC_RowIterator(const void* col_ptr, int col_ptr_type, const int32_t* indices,
                                 const void* data, int data_type, int64_t ncol_ptr, int64_t nelem, int col_idx)
  : iter_fun_(IterateFunctionFromCSC(col_ptr, col_ptr_type, indices, data, data_type,
                                     ncol_ptr, nelem, col_idx)) {
}

/*!
 * \brief Return the column's value at row idx, or zero when no entry is stored there.
 *
 * Advances through the stored entries until the cached row index reaches idx
 * or the column is exhausted; the cached entry is then compared with idx.
 *
 * \param idx Row index to look up
 * \return cur_val_ when the cached row index equals idx, otherwise 0
 */
double CSC_RowIterator::Get(int idx) {
  while (!is_end_ && cur_idx_ < idx) {
    auto entry = iter_fun_(nonzero_idx_);
    if (entry.first < 0) {
      // negative index: no more stored entries in this column
      is_end_ = true;
    } else {
      cur_idx_ = entry.first;
      cur_val_ = entry.second;
      ++nonzero_idx_;
    }
  }
  if (cur_idx_ != idx) {
    return 0.0f;
  }
  return cur_val_;
}

/*!
 * \brief Return the next stored entry of the column and advance the cursor.
 * \return The (row index, value) pair from iter_fun_, or (-1, 0.0) once the
 *         end of the column has been reached
 */
std::pair<int, double> CSC_RowIterator::NextNonZero() {
  if (is_end_) {
    return std::make_pair(-1, 0.0);
  }
  auto entry = iter_fun_(nonzero_idx_);
  ++nonzero_idx_;
  // a negative index is the end-of-column marker; remember it for later calls
  is_end_ = (entry.first < 0);
  return entry;
}